Skip to content

Commit 6220198

Browse files
committed
Make SSPIContextProvider public
1 parent f40acdc commit 6220198

File tree

12 files changed

+95
-32
lines changed

12 files changed

+95
-32
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnection.cs

+23-9
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ private static readonly Dictionary<string, SqlColumnEncryptionKeyStoreProvider>
8888
private IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> _customColumnEncryptionKeyStoreProviders;
8989

9090
private Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;
91+
private Func<SSPIContextProvider> _sspiContextProviderFactory;
9192

9293
internal bool HasColumnEncryptionKeyStoreProvidersRegistered =>
9394
_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0;
@@ -646,7 +647,7 @@ public override string ConnectionString
646647
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions);
647648
}
648649
}
649-
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback));
650+
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _accessTokenCallback, _sspiContextProviderFactory));
650651
_connectionString = value; // Change _connectionString value only after value is validated
651652
CacheConnectionStringProperties();
652653
}
@@ -706,7 +707,7 @@ public string AccessToken
706707
}
707708

708709
// Need to call ConnectionString_Set to do proper pool group check
709-
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value, accessTokenCallback: null));
710+
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: value, accessTokenCallback: null, sspiContextProviderFactory: _sspiContextProviderFactory));
710711
_accessToken = value;
711712
}
712713
}
@@ -729,11 +730,24 @@ public Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticati
729730
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback((SqlConnectionString)ConnectionOptions);
730731
}
731732

732-
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: value));
733+
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: value, sspiContextProviderFactory: _sspiContextProviderFactory));
733734
_accessTokenCallback = value;
734735
}
735736
}
736737

738+
/// <summary>
739+
/// Gets or sets a <see cref="SSPIContextProvider"/>.
740+
/// </summary>
741+
public Func<SSPIContextProvider> SSPIContextProviderFactory
742+
{
743+
get { return _sspiContextProviderFactory; }
744+
set
745+
{
746+
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, credential: _credential, accessToken: null, accessTokenCallback: _accessTokenCallback, sspiContextProviderFactory: value));
747+
_sspiContextProviderFactory = value;
748+
}
749+
}
750+
737751
/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/Database/*' />
738752
[ResDescription(StringsHelper.ResourceNames.SqlConnection_Database)]
739753
[ResCategory(StringsHelper.ResourceNames.SqlConnection_DataSource)]
@@ -1028,7 +1042,7 @@ public SqlCredential Credential
10281042
_credential = value;
10291043

10301044
// Need to call ConnectionString_Set to do proper pool group check
1031-
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback));
1045+
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, accessToken: _accessToken, accessTokenCallback: _accessTokenCallback, _sspiContextProviderFactory));
10321046
}
10331047
}
10341048

@@ -1076,7 +1090,7 @@ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessToken(S
10761090
throw ADP.InvalidMixedUsageOfCredentialAndAccessToken();
10771091
}
10781092

1079-
if(_accessTokenCallback != null)
1093+
if (_accessTokenCallback != null)
10801094
{
10811095
throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
10821096
}
@@ -1098,7 +1112,7 @@ private void CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCa
10981112
throw ADP.InvalidMixedUsageOfAccessTokenCallbackAndAuthentication();
10991113
}
11001114

1101-
if(_accessToken != null)
1115+
if (_accessToken != null)
11021116
{
11031117
throw ADP.InvalidMixedUsageOfAccessTokenAndTokenCallback();
11041118
}
@@ -2212,7 +2226,7 @@ public static void ChangePassword(string connectionString, string newPassword)
22122226
throw ADP.InvalidArgumentLength(nameof(newPassword), TdsEnums.MAXLEN_NEWPASSWORD);
22132227
}
22142228

2215-
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null);
2229+
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, accessTokenCallback: null, sspiContextProviderFactory: null);
22162230

22172231
SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
22182232
if (connectionOptions.IntegratedSecurity)
@@ -2261,7 +2275,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent
22612275
throw ADP.InvalidArgumentLength(nameof(newSecurePassword), TdsEnums.MAXLEN_NEWPASSWORD);
22622276
}
22632277

2264-
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);
2278+
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null, sspiContextProviderFactory: null);
22652279

22662280
SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
22672281

@@ -2300,7 +2314,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString
23002314
if (con != null)
23012315
con.Dispose();
23022316
}
2303-
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null);
2317+
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, accessTokenCallback: null, sspiContextProviderFactory: null);
23042318

23052319
SqlConnectionFactory.SingletonInstance.ClearPool(key);
23062320
}

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlConnectionFactory.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
9696
// This first connection is established to SqlExpress to get the instance name
9797
// of the UserInstance.
9898
SqlConnectionString sseopt = new SqlConnectionString(opt, opt.DataSource, userInstance: true, setEnlistValue: false);
99-
sseConnection = new SqlInternalConnectionTds(identity, sseopt, key.Credential, null, "", null, false, applyTransientFaultHandling: applyTransientFaultHandling);
99+
sseConnection = new SqlInternalConnectionTds(identity, sseopt, key.Credential, null, "", null, false, applyTransientFaultHandling: applyTransientFaultHandling, sspiContextProviderFactory: key.SSPIContextProviderFactory);
100100
// NOTE: Retrieve <UserInstanceName> here. This user instance name will be used below to connect to the Sql Express User Instance.
101101
instanceName = sseConnection.InstanceName;
102102

@@ -133,7 +133,7 @@ override protected DbConnectionInternal CreateConnection(DbConnectionOptions opt
133133
opt = new SqlConnectionString(opt, instanceName, userInstance: false, setEnlistValue: null);
134134
poolGroupProviderInfo = null; // null so we do not pass to constructor below...
135135
}
136-
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool, key.AccessTokenCallback);
136+
return new SqlInternalConnectionTds(identity, opt, key.Credential, poolGroupProviderInfo, "", null, redirectedUserInstance, userOpt, recoverySessionData, applyTransientFaultHandling: applyTransientFaultHandling, key.AccessToken, pool, key.AccessTokenCallback, key.SSPIContextProviderFactory);
137137
}
138138

139139
protected override DbConnectionOptions CreateConnectionOptions(string connectionString, DbConnectionOptions previous)

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlInternalConnectionTds.cs

+5-3
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,8 @@ internal sealed class SqlInternalConnectionTds : SqlInternalConnection, IDisposa
130130
// The Federated Authentication returned by TryGetFedAuthTokenLocked or GetFedAuthToken.
131131
SqlFedAuthToken _fedAuthToken = null;
132132
internal byte[] _accessTokenInBytes;
133-
internal readonly Func<SqlAuthenticationParameters, CancellationToken,Task<SqlAuthenticationToken>> _accessTokenCallback;
133+
internal readonly Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;
134+
internal readonly Func<SSPIContextProvider> _sspiContextProviderFactory;
134135

135136
private readonly ActiveDirectoryAuthenticationTimeoutRetryHelper _activeDirectoryAuthTimeoutRetryHelper;
136137
private readonly SqlAuthenticationProviderManager _sqlAuthenticationProviderManager;
@@ -447,8 +448,8 @@ internal SqlInternalConnectionTds(
447448
bool applyTransientFaultHandling = false,
448449
string accessToken = null,
449450
DbConnectionPool pool = null,
450-
Func<SqlAuthenticationParameters, CancellationToken,
451-
Task<SqlAuthenticationToken>> accessTokenCallback = null) : base(connectionOptions)
451+
Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> accessTokenCallback = null,
452+
Func<SSPIContextProvider> sspiContextProviderFactory = null) : base(connectionOptions)
452453

453454
{
454455
#if DEBUG
@@ -482,6 +483,7 @@ internal SqlInternalConnectionTds(
482483
}
483484

484485
_accessTokenCallback = accessTokenCallback;
486+
_sspiContextProviderFactory = sspiContextProviderFactory;
485487

486488
_activeDirectoryAuthTimeoutRetryHelper = new ActiveDirectoryAuthenticationTimeoutRetryHelper();
487489
_sqlAuthenticationProviderManager = SqlAuthenticationProviderManager.Instance;

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ internal void Connect(
422422
// AD Integrated behaves like Windows integrated when connecting to a non-fedAuth server
423423
if (integratedSecurity || authType == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
424424
{
425-
_authenticationProvider = _physicalStateObj.CreateSSPIContextProvider();
425+
_authenticationProvider = Connection._sspiContextProviderFactory?.Invoke() ?? _physicalStateObj.CreateSSPIContextProvider();
426426
SqlClientEventSource.Log.TryTraceEvent("TdsParser.Connect | SEC | SSPI or Active Directory Authentication Library loaded for SQL Server based integrated authentication");
427427
}
428428

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/SqlConnection.cs

+22-7
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ private static Dictionary<string, SqlColumnEncryptionKeyStoreProvider> s_systemC
7373

7474
private Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticationToken>> _accessTokenCallback;
7575

76+
private Func<SSPIContextProvider> _sspiContextProviderFactory;
77+
7678
internal bool HasColumnEncryptionKeyStoreProvidersRegistered =>
7779
_customColumnEncryptionKeyStoreProviders is not null && _customColumnEncryptionKeyStoreProviders.Count > 0;
7880

@@ -751,7 +753,7 @@ public string AccessToken
751753

752754
_accessToken = value;
753755
// Need to call ConnectionString_Set to do proper pool group check
754-
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, null));
756+
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, null, _sspiContextProviderFactory));
755757
}
756758
}
757759

@@ -773,11 +775,24 @@ public Func<SqlAuthenticationParameters, CancellationToken, Task<SqlAuthenticati
773775
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback((SqlConnectionString)ConnectionOptions);
774776
}
775777

776-
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, null, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, value));
778+
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, null, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, value, _sspiContextProviderFactory));
777779
_accessTokenCallback = value;
778780
}
779781
}
780782

783+
/// <summary>
784+
/// Gets or sets a <see cref="SSPIContextProvider"/>.
785+
/// </summary>
786+
public Func<SSPIContextProvider> SSPIContextProviderFactory
787+
{
788+
get { return _sspiContextProviderFactory; }
789+
set
790+
{
791+
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, null, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, _accessTokenCallback, value));
792+
_sspiContextProviderFactory = value;
793+
}
794+
}
795+
781796
/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/CommandTimeout/*' />
782797
[
783798
DesignerSerializationVisibility(DesignerSerializationVisibility.Hidden),
@@ -862,7 +877,7 @@ override public string ConnectionString
862877
CheckAndThrowOnInvalidCombinationOfConnectionOptionAndAccessTokenCallback(connectionOptions);
863878
}
864879
}
865-
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, _accessTokenCallback));
880+
ConnectionString_Set(new SqlConnectionPoolKey(value, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, _accessTokenCallback, _sspiContextProviderFactory));
866881
_connectionString = value; // Change _connectionString value only after value is validated
867882
CacheConnectionStringProperties();
868883
}
@@ -1213,7 +1228,7 @@ public SqlCredential Credential
12131228
_credential = value;
12141229

12151230
// Need to call ConnectionString_Set to do proper pool group check
1216-
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, _accessTokenCallback));
1231+
ConnectionString_Set(new SqlConnectionPoolKey(_connectionString, _credential, _accessToken, _serverCertificateValidationCallback, _clientCertificateRetrievalCallback, _originalNetworkAddressInfo, _accessTokenCallback, _sspiContextProviderFactory));
12171232
}
12181233
}
12191234

@@ -2775,7 +2790,7 @@ public static void ChangePassword(string connectionString, string newPassword)
27752790
throw ADP.InvalidArgumentLength("newPassword", TdsEnums.MAXLEN_NEWPASSWORD);
27762791
}
27772792

2778-
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null, accessTokenCallback: null);
2793+
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential: null, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null, accessTokenCallback: null, sspiContextProviderFactory: null);
27792794

27802795
SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
27812796
if (connectionOptions.IntegratedSecurity || connectionOptions.Authentication == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
@@ -2831,7 +2846,7 @@ public static void ChangePassword(string connectionString, SqlCredential credent
28312846
throw ADP.InvalidArgumentLength("newSecurePassword", TdsEnums.MAXLEN_NEWPASSWORD);
28322847
}
28332848

2834-
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null, accessTokenCallback: null);
2849+
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null, accessTokenCallback: null, sspiContextProviderFactory: null);
28352850

28362851
SqlConnectionString connectionOptions = SqlConnectionFactory.FindSqlConnectionOptions(key);
28372852

@@ -2876,7 +2891,7 @@ private static void ChangePassword(string connectionString, SqlConnectionString
28762891
throw SQL.ChangePasswordRequires2005();
28772892
}
28782893
}
2879-
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null, accessTokenCallback: null);
2894+
SqlConnectionPoolKey key = new SqlConnectionPoolKey(connectionString, credential, accessToken: null, serverCertificateValidationCallback: null, clientCertificateRetrievalCallback: null, originalNetworkAddressInfo: null, accessTokenCallback: null, sspiContextProviderFactory: null);
28802895

28812896
SqlConnectionFactory.SingletonInstance.ClearPool(key);
28822897
}

0 commit comments

Comments
 (0)