Skip to content

Commit d7bdc19

Browse files
authored
Expose MultiShardConnection with AccessToken support (#201)
* Refresh ShardConnections access tokens * Expose MultiShardConnection with SqlConnectionInfo
1 parent 209a1b8 commit d7bdc19

File tree

5 files changed

+140
-61
lines changed

5 files changed

+140
-61
lines changed

Src/ElasticScale.Client/Query/MultiShardCommand.cs

Lines changed: 7 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -447,26 +447,12 @@ internal MultiShardDataReader ExecuteReader(
447447
{
448448
Contract.Requires(commandRetryPolicy != null && connectionRetryPolicy != null);
449449

450-
try
451-
{
452-
return this.ExecuteReaderAsync(
453-
behavior,
454-
CancellationToken.None,
455-
commandRetryPolicy,
456-
connectionRetryPolicy,
457-
executionPolicy).Result;
458-
}
459-
catch (Exception ex)
460-
{
461-
AggregateException aex = ex as AggregateException;
462-
463-
if (null != aex)
464-
{
465-
throw aex.Flatten().InnerException;
466-
}
467-
468-
throw;
469-
}
450+
return this.ExecuteReaderAsync(
451+
behavior,
452+
CancellationToken.None,
453+
commandRetryPolicy,
454+
connectionRetryPolicy,
455+
executionPolicy).GetAwaiter().GetResult();
470456
}
471457

472458
#endregion
@@ -1235,7 +1221,7 @@ private bool IsExecutionInProgress()
12351221
private List<Tuple<ShardLocation, DbCommand>> GetShardDbCommands()
12361222
{
12371223
return this.Connection
1238-
.ShardConnections
1224+
.GetShardConnections()
12391225
.Select(sc => new Tuple<ShardLocation, DbCommand>(sc.Item1, MultiShardUtils.CloneDbCommand(_dbCommand, sc.Item2)))
12401226
.ToList();
12411227
}

Src/ElasticScale.Client/Query/MultiShardConnection.cs

Lines changed: 97 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,12 @@ public sealed class MultiShardConnection : IDisposable
4949
/// </summary>
5050
private SqlConnectionInfo _connectionInfo;
5151

52-
#endregion
52+
/// <summary>
53+
/// The shard connections
54+
/// </summary>
55+
private List<Tuple<ShardLocation, DbConnection>> _shardConnections;
56+
57+
#endregion
5358

5459
#region Ctors
5560

@@ -89,14 +94,50 @@ public MultiShardConnection(IEnumerable<ShardLocation> shardLocations, string co
8994
InitializeShardConnections(shardLocations);
9095
}
9196

97+
/// <summary>
98+
/// Initializes a new instance of the <see cref="MultiShardConnection"/> class.
99+
/// </summary>
100+
/// <param name="shards">The collection of <see cref="Shard"/>s used for this connection instances.</param>
101+
/// <param name="connectionInfo">
102+
/// These credentials will be used to connect to the <see cref="Shard"/>s.
103+
/// The same credentials are used on all shards.
104+
/// Therefore, all shards need to provide the appropriate permissions for these credentials to execute the command.
105+
/// </param>
106+
/// <remarks>
107+
/// Multiple Active Result Sets (MARS) are not supported and are disabled for any processing at the shards.
108+
/// </remarks>
109+
public MultiShardConnection(IEnumerable<Shard> shards, SqlConnectionInfo connectionInfo)
110+
{
111+
InitializeConnectionInfo(connectionInfo);
112+
InitializeShardConnections(shards);
113+
}
114+
115+
/// <summary>
116+
/// Initializes a new instance of the <see cref="MultiShardConnection"/> class.
117+
/// </summary>
118+
/// <param name="shardLocations">The collection of <see cref="ShardLocation"/>s used for this connection instances.</param>
119+
/// <param name="connectionInfo">
120+
/// These credentials will be used to connect to the <see cref="Shard"/>s.
121+
/// The same credentials are used on all shards.
122+
/// Therefore, all shards need to provide the appropriate permissions for these credentials to execute the command.
123+
/// </param>
124+
/// <remarks>
125+
/// Multiple Active Result Sets (MARS) are not supported and are disabled for any processing at the shards.
126+
/// </remarks>
127+
public MultiShardConnection(IEnumerable<ShardLocation> shardLocations, SqlConnectionInfo connectionInfo)
128+
{
129+
InitializeConnectionInfo(connectionInfo);
130+
InitializeShardConnections(shardLocations);
131+
}
132+
92133
/// <summary>
93134
/// Creates an instance of this class
94135
/// /* TEST ONLY */
95136
/// </summary>
96137
/// <param name="shardConnections">Connections to the shards</param>
97138
internal MultiShardConnection(List<Tuple<ShardLocation, DbConnection>> shardConnections)
98139
{
99-
this.ShardConnections = shardConnections;
140+
this._shardConnections = shardConnections;
100141
}
101142

102143
#endregion
@@ -119,17 +160,11 @@ public IEnumerable<ShardLocation> ShardLocations
119160
{
120161
get
121162
{
122-
return this.ShardConnections.Select(s => s.Item1);
163+
return this._shardConnections.Select(s => s.Item1);
123164
}
124165
}
125166

126-
internal List<Tuple<ShardLocation, DbConnection>> ShardConnections
127-
{
128-
get;
129-
private set;
130-
}
131-
132-
#endregion
167+
#endregion
133168

134169
#region Public Methods
135170

@@ -152,7 +187,7 @@ public void Dispose()
152187
if (!_disposed)
153188
{
154189
// Dispose off the shard connections
155-
this.ShardConnections.ForEach(
190+
this._shardConnections.ForEach(
156191
(c) =>
157192
{
158193
if (c.Item2 != null)
@@ -167,9 +202,35 @@ public void Dispose()
167202
}
168203
}
169204

170-
#endregion
205+
#endregion
206+
207+
208+
#region Internal methods
209+
210+
/// <summary>
211+
/// Gets the shard connections
212+
/// </summary>
213+
internal List<Tuple<ShardLocation, DbConnection>> GetShardConnections()
214+
{
215+
// Refresh the access tokens for all shard connections.
216+
// (Null check because unit tests use internal code path which doesn't initialize _connectionInfo).
217+
if (this._connectionInfo != null)
218+
{
219+
foreach (var shardConnection in _shardConnections)
220+
{
221+
if (shardConnection.Item2 != null)
222+
{
223+
this._connectionInfo.RefreshAccessToken(shardConnection.Item2);
224+
}
225+
}
226+
}
171227

172-
#region Helpers
228+
return _shardConnections;
229+
}
230+
231+
#endregion
232+
233+
#region Helpers
173234

174235
private static void ValidateNotEmpty<T>(
175236
IEnumerable<T> namedCollection,
@@ -181,11 +242,28 @@ private static void ValidateNotEmpty<T>(
181242
}
182243
}
183244

245+
private void InitializeConnectionInfo(SqlConnectionInfo connectionInfo)
246+
{
247+
if (connectionInfo == null)
248+
{
249+
throw new ArgumentNullException("connectionInfo");
250+
}
251+
252+
string updatedConnectionString = InitializeConnectionString(connectionInfo.ConnectionString);
253+
this._connectionInfo = connectionInfo.CloneWithUpdatedConnectionString(updatedConnectionString);
254+
}
255+
184256
private void InitializeConnectionInfo(string connectionString)
257+
{
258+
string updatedConnectionString = InitializeConnectionString(connectionString);
259+
this._connectionInfo = new SqlConnectionInfo(updatedConnectionString);
260+
}
261+
262+
private static string InitializeConnectionString(string connectionString)
185263
{
186264
if (connectionString == null)
187265
{
188-
throw new ArgumentNullException("connectionString");
266+
throw new ArgumentNullException(nameof(connectionString));
189267
}
190268

191269
// Enhance the ApplicationName with this library's name as a suffix
@@ -206,7 +284,7 @@ private void InitializeConnectionInfo(string connectionString)
206284
throw new ArgumentException("InitialCatalog must not be set in the connectionStringBuilder");
207285
}
208286

209-
this._connectionInfo = new SqlConnectionInfo(connectionStringBuilder.ToString());
287+
return connectionStringBuilder.ToString();
210288
}
211289

212290
private void InitializeShardConnections(IEnumerable<Shard> shards)
@@ -220,8 +298,8 @@ private void InitializeShardConnections(IEnumerable<Shard> shards)
220298
this.Shards = shards.ToList();
221299
ValidateNotEmpty(this.Shards, "shards");
222300

223-
this.ShardConnections = this.Shards.Select(
224-
s => CreateDbConnectionForLocation(s.Location, _connectionInfo)).ToList();
301+
this._shardConnections = (this.Shards.Select(
302+
s => CreateDbConnectionForLocation(s.Location, _connectionInfo)).ToList());
225303
}
226304

227305
private void InitializeShardConnections(IEnumerable<ShardLocation> shardLocations)
@@ -236,7 +314,7 @@ private void InitializeShardConnections(IEnumerable<ShardLocation> shardLocation
236314
ValidateNotEmpty(shardLocationsList, "shardLocations");
237315

238316
this.Shards = null;
239-
this.ShardConnections = shardLocationsList.Select(
317+
this._shardConnections = shardLocationsList.Select(
240318
s => CreateDbConnectionForLocation(s, _connectionInfo)).ToList();
241319
}
242320

@@ -270,7 +348,7 @@ private static Tuple<ShardLocation, DbConnection> CreateDbConnectionForLocation(
270348
[System.Diagnostics.CodeAnalysis.SuppressMessage("Microsoft.Design", "CA1031:DoNotCatchGeneralExceptionTypes", Justification = "We do not want to throw on Close.")]
271349
internal void Close()
272350
{
273-
foreach (var conn in this.ShardConnections)
351+
foreach (var conn in this._shardConnections)
274352
{
275353
if (conn.Item2 != null && conn.Item2.State != ConnectionState.Closed)
276354
{

Src/ElasticScale.Client/ShardManagement/SqlStore/SqlConnectionInfo.cs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT license. See LICENSE file in the project root for full license information.
33

44
using System;
5+
using System.Data.Common;
56
using System.Data.SqlClient;
67

78
namespace Microsoft.Azure.SqlDatabase.ElasticScale.ShardManagement
@@ -90,6 +91,22 @@ internal SqlConnection CreateConnection()
9091
Credential = Credential,
9192
};
9293

94+
RefreshAccessToken(conn);
95+
96+
return conn;
97+
}
98+
99+
internal void RefreshAccessToken(DbConnection conn)
100+
{
101+
SqlConnection sqlConn= conn as SqlConnection;
102+
if (sqlConn != null)
103+
{
104+
RefreshAccessToken(sqlConn);
105+
}
106+
}
107+
108+
internal void RefreshAccessToken(SqlConnection conn)
109+
{
93110
if (AccessTokenFactory != null)
94111
{
95112
#if NET451
@@ -98,8 +115,6 @@ internal SqlConnection CreateConnection()
98115
conn.AccessToken = AccessTokenFactory();
99116
#endif
100117
}
101-
102-
return conn;
103118
}
104119

105120
/// <summary>

Test/ElasticScale.Query.UnitTests/MultiShardDataReaderTests.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -125,15 +125,15 @@ public void MyTestInitialize()
125125
// BUT, since we are writing tests at a lower level than that, we need to open
126126
// the connections manually here. Hence the loop below.
127127
//
128-
foreach (var conn in _shardConnection.ShardConnections)
128+
foreach (var conn in _shardConnection.GetShardConnections())
129129
{
130130
conn.Item2.Open();
131131
}
132132

133-
_conn1 = (SqlConnection)_shardConnection.ShardConnections[0].Item2;
134-
_conn2 = (SqlConnection)_shardConnection.ShardConnections[1].Item2;
135-
_conn3 = (SqlConnection)_shardConnection.ShardConnections[2].Item2;
136-
_conns = _shardConnection.ShardConnections.Select(x => (SqlConnection)x.Item2);
133+
_conn1 = (SqlConnection)_shardConnection.GetShardConnections()[0].Item2;
134+
_conn2 = (SqlConnection)_shardConnection.GetShardConnections()[1].Item2;
135+
_conn3 = (SqlConnection)_shardConnection.GetShardConnections()[2].Item2;
136+
_conns = _shardConnection.GetShardConnections().Select(x => (SqlConnection)x.Item2);
137137
}
138138

139139
/// <summary>
@@ -142,7 +142,7 @@ public void MyTestInitialize()
142142
[TestCleanup()]
143143
public void MyTestCleanup()
144144
{
145-
foreach (var conn in _shardConnection.ShardConnections)
145+
foreach (var conn in _shardConnection.GetShardConnections())
146146
{
147147
conn.Item2.Close();
148148
}

0 commit comments

Comments
 (0)