Skip to content

Commit dcf6ac4

Browse files
authored
Use an IBufferWriter<byte> to write the outgoing SSPI blob (#2452)
* Use an IBufferWriter<byte> to write the outgoing SSPI blob This change removes the need to pre-allocate anything for the outgoing blobs of SSPI generation. As part of this: - An internal implementation of ArrayBufferWriter is added for platforms that do not support it - SqlObjectPool is imbued with the ability to create/reset pooled objects - TdsParser/TdsLogin is updated to use pooled ArrayBufferWriter instances to generate SSPI blobs - Native methods are updated to take in Span/* for writeable byte[] - SSPIContextProvider signature is updated to take IBufferWriter
1 parent b5ce725 commit dcf6ac4

34 files changed

+616
-321
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs

+4-5
Original file line numberDiff line numberDiff line change
@@ -8500,8 +8500,7 @@ private void WriteLoginData(SqlLogin rec,
85008500
int length,
85018501
int featureExOffset,
85028502
string clientInterfaceName,
8503-
byte[] outSSPIBuff,
8504-
uint outSSPILength)
8503+
ReadOnlySpan<byte> outSSPI)
85058504
{
85068505
try
85078506
{
@@ -8673,8 +8672,8 @@ private void WriteLoginData(SqlLogin rec,
86738672
WriteShort(offset, _physicalStateObj); // ibSSPI offset
86748673
if (rec.useSSPI)
86758674
{
8676-
WriteShort((int)outSSPILength, _physicalStateObj);
8677-
offset += (int)outSSPILength;
8675+
WriteShort(outSSPI.Length, _physicalStateObj);
8676+
offset += outSSPI.Length;
86788677
}
86798678
else
86808679
{
@@ -8729,7 +8728,7 @@ private void WriteLoginData(SqlLogin rec,
87298728

87308729
// send over SSPI data if we are using SSPI
87318730
if (rec.useSSPI)
8732-
_physicalStateObj.WriteByteArray(outSSPIBuff, (int)outSSPILength, 0);
8731+
_physicalStateObj.WriteByteSpan(outSSPI);
87338732

87348733
WriteString(rec.attachDBFilename, _physicalStateObj);
87358734
if (!rec.useSSPI && !(_connHandler._federatedAuthenticationInfoRequested || _connHandler._federatedAuthenticationRequested))

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParserStateObject.netcore.cs

-137
Original file line numberDiff line numberDiff line change
@@ -793,143 +793,6 @@ internal void WriteByte(byte b)
793793
_outBuff[_outBytesUsed++] = b;
794794
}
795795

796-
internal Task WriteByteSpan(ReadOnlySpan<byte> span, bool canAccumulate = true, TaskCompletionSource<object> completion = null)
797-
{
798-
return WriteBytes(span, span.Length, 0, canAccumulate, completion);
799-
}
800-
801-
internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumulate = true, TaskCompletionSource<object> completion = null)
802-
{
803-
return WriteBytes(ReadOnlySpan<byte>.Empty, len, offsetBuffer, canAccumulate, completion, b);
804-
}
805-
806-
//
807-
// Takes a span or a byte array and writes it to the buffer
808-
// If you pass in a span and a null array then the span wil be used.
809-
// If you pass in a non-null array then the array will be used and the span is ignored.
810-
// if the span cannot be written into the current packet then the remaining contents of the span are copied to a
811-
// new heap allocated array that will used to callback into the method to continue the write operation.
812-
private Task WriteBytes(ReadOnlySpan<byte> b, int len, int offsetBuffer, bool canAccumulate = true, TaskCompletionSource<object> completion = null, byte[] array = null)
813-
{
814-
if (array != null)
815-
{
816-
b = new ReadOnlySpan<byte>(array, offsetBuffer, len);
817-
}
818-
try
819-
{
820-
bool async = _parser._asyncWrite; // NOTE: We are capturing this now for the assert after the Task is returned, since WritePacket will turn off async if there is an exception
821-
Debug.Assert(async || _asyncWriteCount == 0);
822-
// Do we have to send out in packet size chunks, or can we rely on netlib layer to break it up?
823-
// would prefer to do something like:
824-
//
825-
// if (len > what we have room for || len > out buf)
826-
// flush buffer
827-
// UnsafeNativeMethods.Write(b)
828-
//
829-
830-
int offset = offsetBuffer;
831-
832-
Debug.Assert(b.Length >= len, "Invalid length sent to WriteBytes()!");
833-
834-
// loop through and write the entire array
835-
do
836-
{
837-
if ((_outBytesUsed + len) > _outBuff.Length)
838-
{
839-
// If the remainder of the data won't fit into the buffer, then we have to put
840-
// whatever we can into the buffer, and flush that so we can then put more into
841-
// the buffer on the next loop of the while.
842-
843-
int remainder = _outBuff.Length - _outBytesUsed;
844-
845-
// write the remainder
846-
Span<byte> copyTo = _outBuff.AsSpan(_outBytesUsed, remainder);
847-
ReadOnlySpan<byte> copyFrom = b.Slice(0, remainder);
848-
849-
Debug.Assert(copyTo.Length == copyFrom.Length, $"copyTo.Length:{copyTo.Length} and copyFrom.Length{copyFrom.Length:D} should be the same");
850-
851-
copyFrom.CopyTo(copyTo);
852-
853-
offset += remainder;
854-
_outBytesUsed += remainder;
855-
len -= remainder;
856-
b = b.Slice(remainder, len);
857-
858-
Task packetTask = WritePacket(TdsEnums.SOFTFLUSH, canAccumulate);
859-
860-
if (packetTask != null)
861-
{
862-
Task task = null;
863-
Debug.Assert(async, "Returned task in sync mode");
864-
if (completion == null)
865-
{
866-
completion = new TaskCompletionSource<object>();
867-
task = completion.Task; // we only care about return from topmost call, so do not access Task property in other cases
868-
}
869-
870-
if (array == null)
871-
{
872-
byte[] tempArray = new byte[len];
873-
Span<byte> copyTempTo = tempArray.AsSpan();
874-
875-
Debug.Assert(copyTempTo.Length == b.Length, $"copyTempTo.Length:{copyTempTo.Length} and copyTempFrom.Length:{b.Length:D} should be the same");
876-
877-
b.CopyTo(copyTempTo);
878-
array = tempArray;
879-
offset = 0;
880-
}
881-
882-
WriteBytesSetupContinuation(array, len, completion, offset, packetTask);
883-
return task;
884-
}
885-
}
886-
else
887-
{
888-
//((stateObj._outBytesUsed + len) <= stateObj._outBuff.Length )
889-
// Else the remainder of the string will fit into the buffer, so copy it into the
890-
// buffer and then break out of the loop.
891-
892-
Span<byte> copyTo = _outBuff.AsSpan(_outBytesUsed, len);
893-
ReadOnlySpan<byte> copyFrom = b.Slice(0, len);
894-
895-
Debug.Assert(copyTo.Length == copyFrom.Length, $"copyTo.Length:{copyTo.Length} and copyFrom.Length:{copyFrom.Length:D} should be the same");
896-
897-
copyFrom.CopyTo(copyTo);
898-
899-
// handle out buffer bytes used counter
900-
_outBytesUsed += len;
901-
break;
902-
}
903-
} while (len > 0);
904-
905-
if (completion != null)
906-
{
907-
completion.SetResult(null);
908-
}
909-
return null;
910-
}
911-
catch (Exception e)
912-
{
913-
if (completion != null)
914-
{
915-
completion.SetException(e);
916-
return null;
917-
}
918-
else
919-
{
920-
throw;
921-
}
922-
}
923-
}
924-
925-
// This is in its own method to avoid always allocating the lambda in WriteBytes
926-
private void WriteBytesSetupContinuation(byte[] array, int len, TaskCompletionSource<object> completion, int offset, Task packetTask)
927-
{
928-
AsyncHelper.ContinueTask(packetTask, completion,
929-
onSuccess: () => WriteBytes(ReadOnlySpan<byte>.Empty, len: len, offsetBuffer: offset, canAccumulate: false, completion: completion, array)
930-
);
931-
}
932-
933796
// Dumps contents of buffer to SNI for network write.
934797
internal Task WritePacket(byte flushMode, bool canAccumulate = false)
935798
{

src/Microsoft.Data.SqlClient/netfx/src/Microsoft.Data.SqlClient.csproj

+7-1
Original file line numberDiff line numberDiff line change
@@ -337,7 +337,10 @@
337337
<Compile Include="$(CommonSourceRoot)Microsoft\Data\Sql\SqlDataSourceEnumerator.cs">
338338
<Link>Microsoft\Data\Sql\SqlDataSourceEnumerator.cs</Link>
339339
</Compile>
340-
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AAsyncCallContext.cs">
340+
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
341+
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
342+
</Compile>
343+
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AAsyncCallContext.cs">
341344
<Link>Microsoft\Data\SqlClient\AAsyncCallContext.cs</Link>
342345
</Compile>
343346
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationTimeoutRetryHelper.cs">
@@ -361,6 +364,9 @@
361364
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AlwaysEncryptedKeyConverter.cs">
362365
<Link>Microsoft\Data\SqlClient\AlwaysEncryptedKeyConverter.cs</Link>
363366
</Compile>
367+
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ArrayBufferWriter.cs">
368+
<Link>Microsoft\Data\SqlClient\ArrayBufferWriter.cs</Link>
369+
</Compile>
364370
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\AzureAttestationBasedEnclaveProvider.cs">
365371
<Link>Microsoft\Data\SqlClient\AzureAttestationBasedEnclaveProvider.cs</Link>
366372
</Compile>

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs

+4-5
Original file line numberDiff line numberDiff line change
@@ -8972,8 +8972,7 @@ private void WriteLoginData(SqlLogin rec,
89728972
int length,
89738973
int featureExOffset,
89748974
string clientInterfaceName,
8975-
byte[] outSSPIBuff,
8976-
uint outSSPILength)
8975+
ReadOnlySpan<byte> outSSPI)
89778976
{
89788977
try
89798978
{
@@ -9145,8 +9144,8 @@ private void WriteLoginData(SqlLogin rec,
91459144
WriteShort(offset, _physicalStateObj); // ibSSPI offset
91469145
if (rec.useSSPI)
91479146
{
9148-
WriteShort((int)outSSPILength, _physicalStateObj);
9149-
offset += (int)outSSPILength;
9147+
WriteShort(outSSPI.Length, _physicalStateObj);
9148+
offset += outSSPI.Length;
91509149
}
91519150
else
91529151
{
@@ -9205,7 +9204,7 @@ private void WriteLoginData(SqlLogin rec,
92059204

92069205
// send over SSPI data if we are using SSPI
92079206
if (rec.useSSPI)
9208-
_physicalStateObj.WriteByteArray(outSSPIBuff, (int)outSSPILength, 0);
9207+
_physicalStateObj.WriteByteSpan(outSSPI);
92099208

92109209
WriteString(rec.attachDBFilename, _physicalStateObj);
92119210
if (!rec.useSSPI && !(_connHandler._federatedAuthenticationInfoRequested || _connHandler._federatedAuthenticationRequested))

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs

-96
Original file line numberDiff line numberDiff line change
@@ -868,102 +868,6 @@ internal void WriteByte(byte b)
868868
_outBuff[_outBytesUsed++] = b;
869869
}
870870

871-
internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumulate = true, TaskCompletionSource<object> completion = null)
872-
{
873-
try
874-
{
875-
bool async = _parser._asyncWrite; // NOTE: We are capturing this now for the assert after the Task is returned, since WritePacket will turn off async if there is an exception
876-
Debug.Assert(async || _asyncWriteCount == 0);
877-
// Do we have to send out in packet size chunks, or can we rely on netlib layer to break it up?
878-
// would prefer to do something like:
879-
//
880-
// if (len > what we have room for || len > out buf)
881-
// flush buffer
882-
// UnsafeNativeMethods.Write(b)
883-
//
884-
885-
int offset = offsetBuffer;
886-
887-
Debug.Assert(b.Length >= len, "Invalid length sent to WriteByteArray()!");
888-
889-
// loop through and write the entire array
890-
do
891-
{
892-
if ((_outBytesUsed + len) > _outBuff.Length)
893-
{
894-
// If the remainder of the data won't fit into the buffer, then we have to put
895-
// whatever we can into the buffer, and flush that so we can then put more into
896-
// the buffer on the next loop of the while.
897-
898-
int remainder = _outBuff.Length - _outBytesUsed;
899-
900-
// write the remainder
901-
Buffer.BlockCopy(b, offset, _outBuff, _outBytesUsed, remainder);
902-
903-
// handle counters
904-
offset += remainder;
905-
_outBytesUsed += remainder;
906-
len -= remainder;
907-
908-
Task packetTask = WritePacket(TdsEnums.SOFTFLUSH, canAccumulate);
909-
910-
if (packetTask != null)
911-
{
912-
Task task = null;
913-
Debug.Assert(async, "Returned task in sync mode");
914-
if (completion == null)
915-
{
916-
completion = new TaskCompletionSource<object>();
917-
task = completion.Task; // we only care about return from topmost call, so do not access Task property in other cases
918-
}
919-
WriteByteArraySetupContinuation(b, len, completion, offset, packetTask);
920-
return task;
921-
}
922-
923-
}
924-
else
925-
{
926-
//((stateObj._outBytesUsed + len) <= stateObj._outBuff.Length )
927-
// Else the remainder of the string will fit into the buffer, so copy it into the
928-
// buffer and then break out of the loop.
929-
930-
Buffer.BlockCopy(b, offset, _outBuff, _outBytesUsed, len);
931-
932-
// handle out buffer bytes used counter
933-
_outBytesUsed += len;
934-
break;
935-
}
936-
} while (len > 0);
937-
938-
if (completion != null)
939-
{
940-
completion.SetResult(null);
941-
}
942-
return null;
943-
}
944-
catch (Exception e)
945-
{
946-
if (completion != null)
947-
{
948-
completion.SetException(e);
949-
return null;
950-
}
951-
else
952-
{
953-
throw;
954-
}
955-
}
956-
}
957-
958-
// This is in its own method to avoid always allocating the lambda in WriteByteArray
959-
private void WriteByteArraySetupContinuation(byte[] b, int len, TaskCompletionSource<object> completion, int offset, Task packetTask)
960-
{
961-
AsyncHelper.ContinueTask(packetTask, completion,
962-
() => WriteByteArray(b, len: len, offsetBuffer: offset, canAccumulate: false, completion: completion),
963-
connectionToDoom: _parser.Connection
964-
);
965-
}
966-
967871
// Dumps contents of buffer to SNI for network write.
968872
internal Task WritePacket(byte flushMode, bool canAccumulate = false)
969873
{

src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/ISniNativeMethods.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ unsafe uint SniSecGenClientContextWrapper(
7272
SNIHandle pConn,
7373
byte* pIn,
7474
uint cbIn,
75-
byte[] pOut,
75+
byte* pOut,
7676
ref uint pcbOut,
7777
out bool pfDone,
7878
byte* szServerInfo,

src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethods.netcore.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ public unsafe uint SniSecGenClientContextWrapper(
111111
SNIHandle pConn,
112112
byte* pIn,
113113
uint cbIn,
114-
byte[] pOut,
114+
byte* pOut,
115115
ref uint pcbOut,
116116
out bool pfDone,
117117
byte* szServerInfo,
@@ -265,7 +265,7 @@ private static extern unsafe uint SNISecGenClientContextWrapper(
265265
[In] SNIHandle pConn,
266266
[In, Out] byte* pIn,
267267
uint cbIn,
268-
[In, Out] byte[] pOut,
268+
[In, Out] byte* pOut,
269269
[In] ref uint pcbOut,
270270
[MarshalAs(UnmanagedType.Bool)] out bool pfDone,
271271
byte* szServerInfo,

src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsArm64.netfx.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ public unsafe uint SniSecGenClientContextWrapper(
111111
SNIHandle pConn,
112112
byte* pIn,
113113
uint cbIn,
114-
byte[] pOut,
114+
byte* pOut,
115115
ref uint pcbOut,
116116
out bool pfDone,
117117
byte* szServerInfo,
@@ -265,7 +265,7 @@ private static extern unsafe uint SNISecGenClientContextWrapper(
265265
[In] SNIHandle pConn,
266266
[In, Out] byte* pIn,
267267
uint cbIn,
268-
[In, Out] byte[] pOut,
268+
[In, Out] byte* pOut,
269269
[In] ref uint pcbOut,
270270
[MarshalAs(UnmanagedType.Bool)] out bool pfDone,
271271
byte* szServerInfo,

src/Microsoft.Data.SqlClient/src/Interop/Windows/Sni/SniNativeMethodsNotSupported.netfx.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ public unsafe uint SniSecGenClientContextWrapper(
115115
SNIHandle pConn,
116116
byte* pIn,
117117
uint cbIn,
118-
byte[] pOut,
118+
byte* pOut,
119119
ref uint pcbOut,
120120
out bool pfDone,
121121
byte* szServerInfo,

0 commit comments

Comments
 (0)