Skip to content

Commit 2b73e80

Browse files
committed
Return IMemoryOwner<byte>
1 parent 7139cea commit 2b73e80

14 files changed

+176
-85
lines changed

src/Microsoft.Data.SqlClient/netcore/src/Interop/SNINativeMethodWrapper.Windows.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ private static extern unsafe uint SNISecGenClientContextWrapper(
317317
[In] SNIHandle pConn,
318318
[In, Out] byte* pIn,
319319
uint cbIn,
320-
[In, Out] byte[] pOut,
320+
[In, Out] byte* pOut,
321321
[In] ref uint pcbOut,
322322
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
323323
byte* szServerInfo,
@@ -471,17 +471,18 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int
471471
}
472472
}
473473

474-
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
474+
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, ref uint sendLength, byte[] serverUserName)
475475
{
476476
fixed (byte* pin_serverUserName = &serverUserName[0])
477477
fixed (byte* pInBuff = inBuff)
478+
fixed (byte* pOutBuff = outBuff)
478479
{
479480
bool local_fDone;
480481
return SNISecGenClientContextWrapper(
481482
pConnectionObject,
482483
pInBuff,
483484
(uint)inBuff.Length,
484-
OutBuff,
485+
pOutBuff,
485486
ref sendLength,
486487
out local_fDone,
487488
pin_serverUserName,

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SNI/SNIProxy.cs

+6-11
Original file line numberDiff line numberDiff line change
@@ -32,26 +32,25 @@ internal class SNIProxy
3232
/// </summary>
3333
/// <param name="sspiClientContextStatus">SSPI client context status</param>
3434
/// <param name="receivedBuff">Receive buffer</param>
35-
/// <param name="sendBuff">Send buffer</param>
3635
/// <param name="serverName">Service Principal Name buffer</param>
37-
/// <returns>SNI error code</returns>
38-
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, ref byte[] sendBuff, byte[][] serverName)
36+
/// <returns>Memory for response</returns>
37+
internal static IMemoryOwner<byte> GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, byte[][] serverName)
3938
{
4039
// TODO: this should use ReadOnlyMemory all the way through
4140
var array = ArrayPool<byte>.Shared.Rent(receivedBuff.Length);
4241

4342
try
4443
{
4544
receivedBuff.CopyTo(array);
46-
GenSspiClientContext(sspiClientContextStatus, array, receivedBuff.Length, ref sendBuff, serverName);
45+
return GenSspiClientContext(sspiClientContextStatus, array, receivedBuff.Length, serverName);
4746
}
4847
finally
4948
{
5049
ArrayPool<byte>.Shared.Return(array);
5150
}
5251
}
5352

54-
private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, int receivedBuffLength, ref byte[] sendBuff, byte[][] serverName)
53+
private static IMemoryOwner<byte> GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, int receivedBuffLength, byte[][] serverName)
5554
{
5655
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
5756
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
@@ -105,12 +104,6 @@ private static void GenSspiClientContext(SspiClientContextStatus sspiClientConte
105104
outSecurityBuffer.token = null;
106105
}
107106

108-
sendBuff = outSecurityBuffer.token;
109-
if (sendBuff == null)
110-
{
111-
sendBuff = Array.Empty<byte>();
112-
}
113-
114107
sspiClientContextStatus.SecurityContext = securityContext;
115108
sspiClientContextStatus.ContextFlags = contextFlags;
116109
sspiClientContextStatus.CredentialsHandle = credentialsHandle;
@@ -130,6 +123,8 @@ private static void GenSspiClientContext(SspiClientContextStatus sspiClientConte
130123
throw new InvalidOperationException(SQLMessage.SSPIGenerateError() + Environment.NewLine + statusCode);
131124
}
132125
}
126+
127+
return outSecurityBuffer.token.AsMemoryOwner();
133128
}
134129

135130
private static bool IsErrorStatus(SecurityStatusPalErrorCode errorCode)

src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/TdsParser.cs

+4-5
Original file line numberDiff line numberDiff line change
@@ -8122,8 +8122,7 @@ private void WriteLoginData(SqlLogin rec,
81228122
int length,
81238123
int featureExOffset,
81248124
string clientInterfaceName,
8125-
byte[] outSSPIBuff,
8126-
uint outSSPILength)
8125+
ReadOnlySpan<byte> outSSPIBuff)
81278126
{
81288127
try
81298128
{
@@ -8291,8 +8290,8 @@ private void WriteLoginData(SqlLogin rec,
82918290
WriteShort(offset, _physicalStateObj); // ibSSPI offset
82928291
if (rec.useSSPI)
82938292
{
8294-
WriteShort((int)outSSPILength, _physicalStateObj);
8295-
offset += (int)outSSPILength;
8293+
WriteShort((int)outSSPIBuff.Length, _physicalStateObj);
8294+
offset += outSSPIBuff.Length;
82968295
}
82978296
else
82988297
{
@@ -8347,7 +8346,7 @@ private void WriteLoginData(SqlLogin rec,
83478346

83488347
// send over SSPI data if we are using SSPI
83498348
if (rec.useSSPI)
8350-
_physicalStateObj.WriteByteArray(outSSPIBuff, (int)outSSPILength, 0);
8349+
_physicalStateObj.WriteByteSpan(outSSPIBuff);
83518350

83528351
WriteString(rec.attachDBFilename, _physicalStateObj);
83538352
if (!rec.useSSPI && !(_connHandler._federatedAuthenticationInfoRequested || _connHandler._federatedAuthenticationRequested))

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperARM64.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ internal static extern unsafe uint SNISecGenClientContextWrapper(
118118
[In] SNIHandle pConn,
119119
[In, Out] byte* pIn,
120120
uint cbIn,
121-
[In, Out] byte[] pOut,
121+
[In, Out] byte* pOut,
122122
[In] ref uint pcbOut,
123123
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
124124
byte* szServerInfo,

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX64.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ internal static extern unsafe uint SNISecGenClientContextWrapper(
118118
[In] SNIHandle pConn,
119119
[In, Out] byte* pIn,
120120
uint cbIn,
121-
[In, Out] byte[] pOut,
121+
[In, Out] byte* pOut,
122122
[In] ref uint pcbOut,
123123
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
124124
byte* szServerInfo,

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeManagedWrapperX86.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ internal static extern unsafe uint SNISecGenClientContextWrapper(
118118
[In] SNIHandle pConn,
119119
[In, Out] byte* pIn,
120120
uint cbIn,
121-
[In, Out] byte[] pOut,
121+
[In, Out] byte* pOut,
122122
[In] ref uint pcbOut,
123123
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
124124
byte* szServerInfo,

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/Interop/SNINativeMethodWrapper.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,7 @@ private static unsafe void SNIPacketSetData(SNIPacket pPacket, [In] byte* pbBuf,
890890
private static unsafe uint SNISecGenClientContextWrapper(
891891
[In] SNIHandle pConn,
892892
[In, Out] ReadOnlySpan<byte> pIn,
893-
[In, Out] byte[] pOut,
893+
[In, Out] byte* pOut,
894894
[In] ref uint pcbOut,
895895
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
896896
byte* szServerInfo,
@@ -1380,15 +1380,16 @@ Int32[] passwordOffsets // Offset into data buffer where the password to be w
13801380
}
13811381
}
13821382

1383-
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
1383+
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> OutBuff, ref uint sendLength, byte[] serverUserName)
13841384
{
13851385
fixed (byte* pin_serverUserName = &serverUserName[0])
1386+
fixed (byte* pOutBuff = OutBuff)
13861387
{
13871388
bool local_fDone;
13881389
return SNISecGenClientContextWrapper(
13891390
pConnectionObject,
13901391
inBuff,
1391-
OutBuff,
1392+
pOutBuff,
13921393
ref sendLength,
13931394
out local_fDone,
13941395
pin_serverUserName,

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParser.cs

+4-5
Original file line numberDiff line numberDiff line change
@@ -8919,8 +8919,7 @@ private void WriteLoginData(SqlLogin rec,
89198919
int length,
89208920
int featureExOffset,
89218921
string clientInterfaceName,
8922-
byte[] outSSPIBuff,
8923-
uint outSSPILength)
8922+
ReadOnlySpan<byte> outSSPIBuff)
89248923
{
89258924
try
89268925
{
@@ -9091,8 +9090,8 @@ private void WriteLoginData(SqlLogin rec,
90919090
WriteShort(offset, _physicalStateObj); // ibSSPI offset
90929091
if (rec.useSSPI)
90939092
{
9094-
WriteShort((int)outSSPILength, _physicalStateObj);
9095-
offset += (int)outSSPILength;
9093+
WriteShort((int)outSSPIBuff.Length, _physicalStateObj);
9094+
offset += (int)outSSPIBuff.Length;
90969095
}
90979096
else
90989097
{
@@ -9151,7 +9150,7 @@ private void WriteLoginData(SqlLogin rec,
91519150

91529151
// send over SSPI data if we are using SSPI
91539152
if (rec.useSSPI)
9154-
_physicalStateObj.WriteByteArray(outSSPIBuff, (int)outSSPILength, 0);
9153+
_physicalStateObj.WriteByteSpan(outSSPIBuff);
91559154

91569155
WriteString(rec.attachDBFilename, _physicalStateObj);
91579156
if (!rec.useSSPI && !(_connHandler._federatedAuthenticationInfoRequested || _connHandler._federatedAuthenticationRequested))

src/Microsoft.Data.SqlClient/netfx/src/Microsoft/Data/SqlClient/TdsParserStateObject.netfx.cs

+51-12
Original file line numberDiff line numberDiff line change
@@ -931,12 +931,30 @@ internal void WriteByte(byte b)
931931
_outBuff[_outBytesUsed++] = b;
932932
}
933933

934+
internal Task WriteByteSpan(ReadOnlySpan<byte> span, bool canAccumulate = true, TaskCompletionSource<object> completion = null)
935+
{
936+
return WriteBytes(span, span.Length, 0, canAccumulate, completion);
937+
}
938+
934939
internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumulate = true, TaskCompletionSource<object> completion = null)
935940
{
941+
return WriteBytes(ReadOnlySpan<byte>.Empty, len, offsetBuffer, canAccumulate, completion, b);
942+
}
943+
944+
//
945+
// Takes a span or a byte array and writes it to the buffer
946+
// If you pass in a span and a null array then the span wil be used.
947+
// If you pass in a non-null array then the array will be used and the span is ignored.
948+
// if the span cannot be written into the current packet then the remaining contents of the span are copied to a
949+
// new heap allocated array that will used to callback into the method to continue the write operation.
950+
private Task WriteBytes(ReadOnlySpan<byte> b, int len, int offsetBuffer, bool canAccumulate = true, TaskCompletionSource<object> completion = null, byte[] array = null)
951+
{
952+
if (array != null)
953+
{
954+
b = new ReadOnlySpan<byte>(array, offsetBuffer, len);
955+
}
936956
try
937957
{
938-
TdsParser.ReliabilitySection.Assert("unreliable call to WriteByteArray"); // you need to setup for a thread abort somewhere before you call this method
939-
940958
bool async = _parser._asyncWrite; // NOTE: We are capturing this now for the assert after the Task is returned, since WritePacket will turn off async if there is an exception
941959
Debug.Assert(async || _asyncWriteCount == 0);
942960
// Do we have to send out in packet size chunks, or can we rely on netlib layer to break it up?
@@ -949,7 +967,7 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu
949967

950968
int offset = offsetBuffer;
951969

952-
Debug.Assert(b.Length >= len, "Invalid length sent to WriteByteArray()!");
970+
Debug.Assert(b.Length >= len, "Invalid length sent to WriteBytes()!");
953971

954972
// loop through and write the entire array
955973
do
@@ -963,12 +981,17 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu
963981
int remainder = _outBuff.Length - _outBytesUsed;
964982

965983
// write the remainder
966-
Buffer.BlockCopy(b, offset, _outBuff, _outBytesUsed, remainder);
984+
Span<byte> copyTo = _outBuff.AsSpan(_outBytesUsed, remainder);
985+
ReadOnlySpan<byte> copyFrom = b.Slice(0, remainder);
986+
987+
Debug.Assert(copyTo.Length == copyFrom.Length, $"copyTo.Length:{copyTo.Length} and copyFrom.Length{copyFrom.Length:D} should be the same");
988+
989+
copyFrom.CopyTo(copyTo);
967990

968-
// handle counters
969991
offset += remainder;
970992
_outBytesUsed += remainder;
971993
len -= remainder;
994+
b = b.Slice(remainder, len);
972995

973996
Task packetTask = WritePacket(TdsEnums.SOFTFLUSH, canAccumulate);
974997

@@ -981,18 +1004,35 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu
9811004
completion = new TaskCompletionSource<object>();
9821005
task = completion.Task; // we only care about return from topmost call, so do not access Task property in other cases
9831006
}
984-
WriteByteArraySetupContinuation(b, len, completion, offset, packetTask);
1007+
1008+
if (array == null)
1009+
{
1010+
byte[] tempArray = new byte[len];
1011+
Span<byte> copyTempTo = tempArray.AsSpan();
1012+
1013+
Debug.Assert(copyTempTo.Length == b.Length, $"copyTempTo.Length:{copyTempTo.Length} and copyTempFrom.Length:{b.Length:D} should be the same");
1014+
1015+
b.CopyTo(copyTempTo);
1016+
array = tempArray;
1017+
offset = 0;
1018+
}
1019+
1020+
WriteBytesSetupContinuation(array, len, completion, offset, packetTask);
9851021
return task;
9861022
}
987-
9881023
}
9891024
else
9901025
{
9911026
//((stateObj._outBytesUsed + len) <= stateObj._outBuff.Length )
9921027
// Else the remainder of the string will fit into the buffer, so copy it into the
9931028
// buffer and then break out of the loop.
9941029

995-
Buffer.BlockCopy(b, offset, _outBuff, _outBytesUsed, len);
1030+
Span<byte> copyTo = _outBuff.AsSpan(_outBytesUsed, len);
1031+
ReadOnlySpan<byte> copyFrom = b.Slice(0, len);
1032+
1033+
Debug.Assert(copyTo.Length == copyFrom.Length, $"copyTo.Length:{copyTo.Length} and copyFrom.Length:{copyFrom.Length:D} should be the same");
1034+
1035+
copyFrom.CopyTo(copyTo);
9961036

9971037
// handle out buffer bytes used counter
9981038
_outBytesUsed += len;
@@ -1021,12 +1061,11 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu
10211061
}
10221062

10231063
// This is in its own method to avoid always allocating the lambda in WriteByteArray
1024-
private void WriteByteArraySetupContinuation(byte[] b, int len, TaskCompletionSource<object> completion, int offset, Task packetTask)
1064+
private void WriteBytesSetupContinuation(byte[] array, int len, TaskCompletionSource<object> completion, int offset, Task packetTask)
10251065
{
10261066
AsyncHelper.ContinueTask(packetTask, completion,
1027-
() => WriteByteArray(b, len: len, offsetBuffer: offset, canAccumulate: false, completion: completion),
1028-
connectionToDoom: _parser.Connection
1029-
);
1067+
onSuccess: () => WriteBytes(ReadOnlySpan<byte>.Empty, len: len, offsetBuffer: offset, canAccumulate: false, completion: completion, array)
1068+
);
10301069
}
10311070

10321071
// Dumps contents of buffer to SNI for network write.

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/SSPI/ManagedSSPIContextProvider.cs

+10-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#if !NETFRAMEWORK && !NET7_0_OR_GREATER
22

33
using System;
4+
using System.Buffers;
45
using Microsoft.Data.SqlClient.SNI;
56

67
#nullable enable
@@ -11,13 +12,18 @@ internal sealed class ManagedSSPIContextProvider : SSPIContextProvider
1112
{
1213
private SspiClientContextStatus? _sspiClientContextStatus;
1314

14-
internal override void GenerateSspiClientContext(ReadOnlyMemory<byte> received, ref byte[] sendBuff, ref uint sendLength, byte[][] _sniSpnBuffer)
15+
internal override IMemoryOwner<byte> GenerateSspiClientContext(ReadOnlyMemory<byte> received, byte[][] _sniSpnBuffer)
1516
{
1617
_sspiClientContextStatus ??= new SspiClientContextStatus();
1718

18-
SNIProxy.GenSspiClientContext(_sspiClientContextStatus, received, ref sendBuff, _sniSpnBuffer);
19-
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.GenerateSspiClientContext | Info | Session Id {0}", _physicalStateObj.SessionId);
20-
sendLength = (uint)(sendBuff != null ? sendBuff.Length : 0);
19+
try
20+
{
21+
return SNIProxy.GenSspiClientContext(_sspiClientContextStatus, received, _sniSpnBuffer);
22+
}
23+
finally
24+
{
25+
SqlClientEventSource.Log.TryTraceEvent("TdsParserStateObjectManaged.GenerateSspiClientContext | Info | Session Id {0}", _physicalStateObj.SessionId);
26+
}
2127
}
2228
}
2329
}

0 commit comments

Comments
 (0)