Skip to content

Commit 1ec663f

Browse files
committed
Use an IBufferWriter<byte> to write the outgoing SSPI blob
This change removes the need to pre-allocate anything for the outgoing blobs of SSPI generation. As part of this: - An internal implementation of ArrayBufferWriter is added for platforms that do not support it - SqlObjectPool is imbued with the ability to create/reset pooled objects - TdsParser/TdsLogin is updated to use pooled ArrayBufferWriter instances to generate SSPI blobs - Native methods are updated to take in Span/* for writeable byte[] - SSPIContextProvider signature is updated to take IBufferWriter
1 parent 4c0f013 commit 1ec663f

18 files changed

+407
-77
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs

+6-3
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ private static extern unsafe uint SNISecGenClientContextWrapper(
317317
[In] SNIHandle pConn,
318318
[In, Out] byte* pIn,
319319
uint cbIn,
320-
[In, Out] byte[] pOut,
320+
[In, Out] byte* pOut,
321321
[In] ref uint pcbOut,
322322
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
323323
byte* szServerInfo,
@@ -471,17 +471,20 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int
471471
}
472472
}
473473

474-
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
474+
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, out uint sendLength, byte[] serverUserName)
475475
{
476+
sendLength = (uint)outBuff.Length;
477+
476478
fixed (byte* pin_serverUserName = &serverUserName[0])
477479
fixed (byte* pInBuff = inBuff)
480+
fixed (byte* pOutBuff = outBuff)
478481
{
479482
bool local_fDone;
480483
return SNISecGenClientContextWrapper(
481484
pConnectionObject,
482485
pInBuff,
483486
(uint)inBuff.Length,
484-
OutBuff,
487+
pOutBuff,
485488
ref sendLength,
486489
out local_fDone,
487490
pin_serverUserName,

src/Microsoft.Data.SqlClient/netcore/src/Microsoft.Data.SqlClient.csproj

+3
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,9 @@
503503
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
504504
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
505505
</Compile>
506+
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ArrayBufferWriter.cs">
507+
<Link>Microsoft\Data\SqlClient\ArrayBufferWriter.cs</Link>
508+
</Compile>
506509
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlParameter.cs">
507510
<Link>Microsoft\Data\SqlClient\SqlParameter.cs</Link>
508511
</Compile>

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs

+6-7
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ internal class SNIProxy
3232
/// </summary>
3333
/// <param name="sspiClientContextStatus">SSPI client context status</param>
3434
/// <param name="receivedBuff">Receive buffer</param>
35-
/// <param name="sendBuff">Send buffer</param>
35+
/// <param name="sendWriter">Writer for send buffer</param>
3636
/// <param name="serverName">Service Principal Name buffer</param>
3737
/// <returns>SNI error code</returns>
38-
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, ref byte[] sendBuff, byte[][] serverName)
38+
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, IBufferWriter<byte> sendWriter, byte[][] serverName)
3939
{
4040
// TODO: this should use ReadOnlyMemory all the way through
4141
byte[] array = null;
@@ -46,10 +46,10 @@ internal static void GenSspiClientContext(SspiClientContextStatus sspiClientCont
4646
receivedBuff.CopyTo(array);
4747
}
4848

49-
GenSspiClientContext(sspiClientContextStatus, array, ref sendBuff, serverName);
49+
GenSspiClientContext(sspiClientContextStatus, array, sendWriter, serverName);
5050
}
5151

52-
private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
52+
private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, IBufferWriter<byte> sendWriter, byte[][] serverName)
5353
{
5454
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
5555
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
@@ -103,10 +103,9 @@ private static void GenSspiClientContext(SspiClientContextStatus sspiClientConte
103103
outSecurityBuffer.token = null;
104104
}
105105

106-
sendBuff = outSecurityBuffer.token;
107-
if (sendBuff == null)
106+
if (outSecurityBuffer.token is { } token)
108107
{
109-
sendBuff = Array.Empty<byte>();
108+
sendWriter.Write(token);
110109
}
111110

112111
sspiClientContextStatus.SecurityContext = securityContext;

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs

+4-5
Original file line numberDiff line numberDiff line change
@@ -8120,8 +8120,7 @@ private void WriteLoginData(SqlLogin rec,
81208120
int length,
81218121
int featureExOffset,
81228122
string clientInterfaceName,
8123-
byte[] outSSPIBuff,
8124-
uint outSSPILength)
8123+
ReadOnlySpan<byte> outSSPI)
81258124
{
81268125
try
81278126
{
@@ -8289,8 +8288,8 @@ private void WriteLoginData(SqlLogin rec,
82898288
WriteShort(offset, _physicalStateObj); // ibSSPI offset
82908289
if (rec.useSSPI)
82918290
{
8292-
WriteShort((int)outSSPILength, _physicalStateObj);
8293-
offset += (int)outSSPILength;
8291+
WriteShort(outSSPI.Length, _physicalStateObj);
8292+
offset += outSSPI.Length;
82948293
}
82958294
else
82968295
{
@@ -8345,7 +8344,7 @@ private void WriteLoginData(SqlLogin rec,
83458344

83468345
// send over SSPI data if we are using SSPI
83478346
if (rec.useSSPI)
8348-
_physicalStateObj.WriteByteArray(outSSPIBuff, (int)outSSPILength, 0);
8347+
_physicalStateObj.WriteByteSpan(outSSPI);
83498348

83508349
WriteString(rec.attachDBFilename, _physicalStateObj);
83518350
if (!rec.useSSPI && !(_connHandler._federatedAuthenticationInfoRequested || _connHandler._federatedAuthenticationRequested))

src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj

+6
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,12 @@
183183
<Compile Include="$(CommonSourceRoot)Microsoft\Data\Sql\SqlNotificationRequest.cs">
184184
<Link>Microsoft\Data\Sql\SqlNotificationRequest.cs</Link>
185185
</Compile>
186+
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
187+
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
188+
</Compile>
189+
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ArrayBufferWriter.cs">
190+
<Link>Microsoft\Data\SqlClient\ArrayBufferWriter.cs</Link>
191+
</Compile>
186192
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationTimeoutRetryHelper.cs">
187193
<Link>Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationTimeoutRetryHelper.cs</Link>
188194
</Compile>

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperARM64.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ internal static extern unsafe uint SNISecGenClientContextWrapper(
118118
[In] SNIHandle pConn,
119119
[In, Out] byte* pIn,
120120
uint cbIn,
121-
[In, Out] byte[] pOut,
121+
[In, Out] byte* pOut,
122122
[In] ref uint pcbOut,
123123
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
124124
byte* szServerInfo,

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX64.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ internal static extern unsafe uint SNISecGenClientContextWrapper(
118118
[In] SNIHandle pConn,
119119
[In, Out] byte* pIn,
120120
uint cbIn,
121-
[In, Out] byte[] pOut,
121+
[In, Out] byte* pOut,
122122
[In] ref uint pcbOut,
123123
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
124124
byte* szServerInfo,

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX86.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ internal static extern unsafe uint SNISecGenClientContextWrapper(
118118
[In] SNIHandle pConn,
119119
[In, Out] byte* pIn,
120120
uint cbIn,
121-
[In, Out] byte[] pOut,
121+
[In, Out] byte* pOut,
122122
[In] ref uint pcbOut,
123123
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
124124
byte* szServerInfo,

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs

+9-6
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,7 @@ private static unsafe void SNIPacketSetData(SNIPacket pPacket, [In] byte* pbBuf,
890890
private static unsafe uint SNISecGenClientContextWrapper(
891891
[In] SNIHandle pConn,
892892
[In, Out] ReadOnlySpan<byte> pIn,
893-
[In, Out] byte[] pOut,
893+
[In, Out] Span<byte> pOut,
894894
[In] ref uint pcbOut,
895895
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
896896
byte* szServerInfo,
@@ -899,15 +899,16 @@ private static unsafe uint SNISecGenClientContextWrapper(
899899
[MarshalAsAttribute(UnmanagedType.LPWStr)] string pwszPassword)
900900
{
901901
fixed (byte* pInPtr = pIn)
902+
fixed (byte* pOutPtr = pOut)
902903
{
903904
switch (s_architecture)
904905
{
905906
case System.Runtime.InteropServices.Architecture.Arm64:
906-
return SNINativeManagedWrapperARM64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
907+
return SNINativeManagedWrapperARM64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOutPtr, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
907908
case System.Runtime.InteropServices.Architecture.X64:
908-
return SNINativeManagedWrapperX64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
909+
return SNINativeManagedWrapperX64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOutPtr, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
909910
case System.Runtime.InteropServices.Architecture.X86:
910-
return SNINativeManagedWrapperX86.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
911+
return SNINativeManagedWrapperX86.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOutPtr, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
911912
default:
912913
throw ADP.SNIPlatformNotSupported(s_architecture.ToString());
913914
}
@@ -1380,15 +1381,17 @@ Int32[] passwordOffsets // Offset into data buffer where the password to be w
13801381
}
13811382
}
13821383

1383-
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
1384+
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, out uint sendLength, byte[] serverUserName)
13841385
{
1386+
sendLength = (uint)outBuff.Length;
1387+
13851388
fixed (byte* pin_serverUserName = &serverUserName[0])
13861389
{
13871390
bool local_fDone;
13881391
return SNISecGenClientContextWrapper(
13891392
pConnectionObject,
13901393
inBuff,
1391-
OutBuff,
1394+
outBuff,
13921395
ref sendLength,
13931396
out local_fDone,
13941397
pin_serverUserName,

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs

+4-5
Original file line numberDiff line numberDiff line change
@@ -8919,8 +8919,7 @@ private void WriteLoginData(SqlLogin rec,
89198919
int length,
89208920
int featureExOffset,
89218921
string clientInterfaceName,
8922-
byte[] outSSPIBuff,
8923-
uint outSSPILength)
8922+
ReadOnlySpan<byte> outSSPI)
89248923
{
89258924
try
89268925
{
@@ -9091,8 +9090,8 @@ private void WriteLoginData(SqlLogin rec,
90919090
WriteShort(offset, _physicalStateObj); // ibSSPI offset
90929091
if (rec.useSSPI)
90939092
{
9094-
WriteShort((int)outSSPILength, _physicalStateObj);
9095-
offset += (int)outSSPILength;
9093+
WriteShort(outSSPI.Length, _physicalStateObj);
9094+
offset += outSSPI.Length;
90969095
}
90979096
else
90989097
{
@@ -9151,7 +9150,7 @@ private void WriteLoginData(SqlLogin rec,
91519150

91529151
// send over SSPI data if we are using SSPI
91539152
if (rec.useSSPI)
9154-
_physicalStateObj.WriteByteArray(outSSPIBuff, (int)outSSPILength, 0);
9153+
_physicalStateObj.WriteByteSpan(outSSPI);
91559154

91569155
WriteString(rec.attachDBFilename, _physicalStateObj);
91579156
if (!rec.useSSPI && !(_connHandler._federatedAuthenticationInfoRequested || _connHandler._federatedAuthenticationRequested))

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs

+48-7
Original file line numberDiff line numberDiff line change
@@ -930,9 +930,28 @@ internal void WriteByte(byte b)
930930
// set byte in buffer and increment the counter for number of bytes used in the out buffer
931931
_outBuff[_outBytesUsed++] = b;
932932
}
933+
internal Task WriteByteSpan(ReadOnlySpan<byte> span, bool canAccumulate = true, TaskCompletionSource<object> completion = null)
934+
{
935+
return WriteBytes(span, span.Length, 0, canAccumulate, completion);
936+
}
933937

934938
internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumulate = true, TaskCompletionSource<object> completion = null)
935939
{
940+
return WriteBytes(ReadOnlySpan<byte>.Empty, len, offsetBuffer, canAccumulate, completion, b);
941+
}
942+
943+
//
944+
// Takes a span or a byte array and writes it to the buffer
945+
// If you pass in a span and a null array then the span wil be used.
946+
// If you pass in a non-null array then the array will be used and the span is ignored.
947+
// if the span cannot be written into the current packet then the remaining contents of the span are copied to a
948+
// new heap allocated array that will used to callback into the method to continue the write operation.
949+
private Task WriteBytes(ReadOnlySpan<byte> b, int len, int offsetBuffer, bool canAccumulate = true, TaskCompletionSource<object> completion = null, byte[] array = null)
950+
{
951+
if (array != null)
952+
{
953+
b = new ReadOnlySpan<byte>(array, offsetBuffer, len);
954+
}
936955
try
937956
{
938957
TdsParser.ReliabilitySection.Assert("unreliable call to WriteByteArray"); // you need to setup for a thread abort somewhere before you call this method
@@ -949,7 +968,7 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu
949968

950969
int offset = offsetBuffer;
951970

952-
Debug.Assert(b.Length >= len, "Invalid length sent to WriteByteArray()!");
971+
Debug.Assert(b.Length >= len, "Invalid length sent to WriteBytes()!");
953972

954973
// loop through and write the entire array
955974
do
@@ -963,12 +982,17 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu
963982
int remainder = _outBuff.Length - _outBytesUsed;
964983

965984
// write the remainder
966-
Buffer.BlockCopy(b, offset, _outBuff, _outBytesUsed, remainder);
985+
Span<byte> copyTo = _outBuff.AsSpan(_outBytesUsed, remainder);
986+
ReadOnlySpan<byte> copyFrom = b.Slice(0, remainder);
987+
988+
Debug.Assert(copyTo.Length == copyFrom.Length, $"copyTo.Length:{copyTo.Length} and copyFrom.Length{copyFrom.Length:D} should be the same");
989+
990+
copyFrom.CopyTo(copyTo);
967991

968-
// handle counters
969992
offset += remainder;
970993
_outBytesUsed += remainder;
971994
len -= remainder;
995+
b = b.Slice(remainder, len);
972996

973997
Task packetTask = WritePacket(TdsEnums.SOFTFLUSH, canAccumulate);
974998

@@ -981,18 +1005,35 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu
9811005
completion = new TaskCompletionSource<object>();
9821006
task = completion.Task; // we only care about return from topmost call, so do not access Task property in other cases
9831007
}
984-
WriteByteArraySetupContinuation(b, len, completion, offset, packetTask);
1008+
1009+
if (array == null)
1010+
{
1011+
byte[] tempArray = new byte[len];
1012+
Span<byte> copyTempTo = tempArray.AsSpan();
1013+
1014+
Debug.Assert(copyTempTo.Length == b.Length, $"copyTempTo.Length:{copyTempTo.Length} and copyTempFrom.Length:{b.Length:D} should be the same");
1015+
1016+
b.CopyTo(copyTempTo);
1017+
array = tempArray;
1018+
offset = 0;
1019+
}
1020+
1021+
WriteBytesSetupContinuation(array, len, completion, offset, packetTask);
9851022
return task;
9861023
}
987-
9881024
}
9891025
else
9901026
{
9911027
//((stateObj._outBytesUsed + len) <= stateObj._outBuff.Length )
9921028
// Else the remainder of the string will fit into the buffer, so copy it into the
9931029
// buffer and then break out of the loop.
9941030

995-
Buffer.BlockCopy(b, offset, _outBuff, _outBytesUsed, len);
1031+
Span<byte> copyTo = _outBuff.AsSpan(_outBytesUsed, len);
1032+
ReadOnlySpan<byte> copyFrom = b.Slice(0, len);
1033+
1034+
Debug.Assert(copyTo.Length == copyFrom.Length, $"copyTo.Length:{copyTo.Length} and copyFrom.Length:{copyFrom.Length:D} should be the same");
1035+
1036+
copyFrom.CopyTo(copyTo);
9961037

9971038
// handle out buffer bytes used counter
9981039
_outBytesUsed += len;
@@ -1021,7 +1062,7 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu
10211062
}
10221063

10231064
// This is in its own method to avoid always allocating the lambda in WriteByteArray
1024-
private void WriteByteArraySetupContinuation(byte[] b, int len, TaskCompletionSource<object> completion, int offset, Task packetTask)
1065+
private void WriteBytesSetupContinuation(byte[] b, int len, TaskCompletionSource<object> completion, int offset, Task packetTask)
10251066
{
10261067
AsyncHelper.ContinueTask(packetTask, completion,
10271068
() => WriteByteArray(b, len: len, offsetBuffer: offset, canAccumulate: false, completion: completion),

0 commit comments

Comments
 (0)