Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use an IBufferWriter<byte> to write the outgoing SSPI blob #2452

Merged
merged 25 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
1ec663f
Use an IBufferWriter<byte> to write the outgoing SSPI blob
twsouthwick Apr 8, 2024
7c3335e
Merge branch 'main' into sspi-writer
twsouthwick Jun 27, 2024
4c3c3f2
Merge remote-tracking branch 'upstream/main' into sspi-writer
twsouthwick Jul 10, 2024
d3bd2b9
fix
twsouthwick Jul 10, 2024
266cf7f
switch to span
twsouthwick May 7, 2024
7d99053
add return
twsouthwick Aug 19, 2024
7e4d15f
Merge remote-tracking branch 'upstream/main' into sspi-writer
twsouthwick Aug 19, 2024
09db047
revert
twsouthwick Aug 19, 2024
a49abc3
use return
twsouthwick Aug 20, 2024
a1703d8
inline
twsouthwick Aug 20, 2024
1781d92
Merge remote-tracking branch 'upstream/main' into sspi-writer
twsouthwick Nov 14, 2024
c41dec3
merge main
twsouthwick Nov 14, 2024
4def406
remove unneeded if/def
twsouthwick Nov 15, 2024
669cc6a
Merge remote-tracking branch 'origin/main' into sspi-writer
twsouthwick Jan 31, 2025
1ba12e6
react to ISniNativeMethods
twsouthwick Jan 31, 2025
c4d3dbe
add to strings.resx
twsouthwick Jan 31, 2025
b00a17c
revert other changes to string designer
twsouthwick Jan 31, 2025
43ec92d
move write methods that are the same to shared file
twsouthwick Jan 31, 2025
7b7ee36
make sure to use correct length
twsouthwick Feb 3, 2025
4f95969
Merge branch 'main' into sspi-writer
twsouthwick Feb 11, 2025
e1986ac
Merge remote-tracking branch 'origin/main' into sspi-writer
twsouthwick Feb 12, 2025
130cdef
put method back
twsouthwick Feb 12, 2025
f40013e
Add comment for pool
twsouthwick Feb 13, 2025
31e3ec9
Add note about file origin and editing restrictions
twsouthwick Feb 13, 2025
e17288c
Merge branch 'main' into sspi-writer
twsouthwick Feb 24, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ private static extern unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] byte* pIn,
uint cbIn,
[In, Out] byte[] pOut,
[In, Out] byte* pOut,
[In] ref uint pcbOut,
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
byte* szServerInfo,
Expand Down Expand Up @@ -471,16 +471,19 @@ internal static unsafe void SNIPacketSetData(SNIPacket packet, byte[] data, int
}
}

internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, out uint sendLength, byte[] serverUserName)
{
sendLength = (uint)outBuff.Length;

fixed (byte* pin_serverUserName = &serverUserName[0])
fixed (byte* pInBuff = inBuff)
fixed (byte* pOutBuff = outBuff)
{
return SNISecGenClientContextWrapper(
pConnectionObject,
pInBuff,
(uint)inBuff.Length,
OutBuff,
pOutBuff,
ref sendLength,
out _,
pin_serverUserName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,9 @@
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ArrayBufferWriter.cs">
<Link>Microsoft\Data\SqlClient\ArrayBufferWriter.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlParameter.cs">
<Link>Microsoft\Data\SqlClient\SqlParameter.cs</Link>
</Compile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ internal class SNIProxy
/// </summary>
/// <param name="sspiClientContextStatus">SSPI client context status</param>
/// <param name="receivedBuff">Receive buffer</param>
/// <param name="sendBuff">Send buffer</param>
/// <param name="sendWriter">Writer for send buffer</param>
/// <param name="serverName">Service Principal Name buffer</param>
/// <returns>SNI error code</returns>
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlyMemory<byte> receivedBuff, ref byte[] sendBuff, byte[][] serverName)
internal static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, ReadOnlySpan<byte> receivedBuff, IBufferWriter<byte> sendWriter, byte[][] serverName)
{
// TODO: this should use ReadOnlyMemory all the way through
byte[] array = null;
Expand All @@ -46,10 +46,10 @@ internal static void GenSspiClientContext(SspiClientContextStatus sspiClientCont
receivedBuff.CopyTo(array);
}

GenSspiClientContext(sspiClientContextStatus, array, ref sendBuff, serverName);
GenSspiClientContext(sspiClientContextStatus, array, sendWriter, serverName);
}

private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, ref byte[] sendBuff, byte[][] serverName)
private static void GenSspiClientContext(SspiClientContextStatus sspiClientContextStatus, byte[] receivedBuff, IBufferWriter<byte> sendWriter, byte[][] serverName)
{
SafeDeleteContext securityContext = sspiClientContextStatus.SecurityContext;
ContextFlagsPal contextFlags = sspiClientContextStatus.ContextFlags;
Expand Down Expand Up @@ -103,10 +103,9 @@ private static void GenSspiClientContext(SspiClientContextStatus sspiClientConte
outSecurityBuffer.token = null;
}

sendBuff = outSecurityBuffer.token;
if (sendBuff == null)
if (outSecurityBuffer.token is { } token)
{
sendBuff = Array.Empty<byte>();
sendWriter.Write(token);
}

sspiClientContextStatus.SecurityContext = securityContext;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8425,8 +8425,7 @@ private void WriteLoginData(SqlLogin rec,
int length,
int featureExOffset,
string clientInterfaceName,
byte[] outSSPIBuff,
uint outSSPILength)
ReadOnlySpan<byte> outSSPI)
{
try
{
Expand Down Expand Up @@ -8596,8 +8595,8 @@ private void WriteLoginData(SqlLogin rec,
WriteShort(offset, _physicalStateObj); // ibSSPI offset
if (rec.useSSPI)
{
WriteShort((int)outSSPILength, _physicalStateObj);
offset += (int)outSSPILength;
WriteShort(outSSPI.Length, _physicalStateObj);
offset += outSSPI.Length;
}
else
{
Expand Down Expand Up @@ -8652,7 +8651,7 @@ private void WriteLoginData(SqlLogin rec,

// send over SSPI data if we are using SSPI
if (rec.useSSPI)
_physicalStateObj.WriteByteArray(outSSPIBuff, (int)outSSPILength, 0);
_physicalStateObj.WriteByteSpan(outSSPI);

WriteString(rec.attachDBFilename, _physicalStateObj);
if (!rec.useSSPI && !(_connHandler._federatedAuthenticationInfoRequested || _connHandler._federatedAuthenticationRequested))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,12 @@
<Compile Include="$(CommonSourceRoot)Microsoft\Data\Sql\SqlNotificationRequest.cs">
<Link>Microsoft\Data\Sql\SqlNotificationRequest.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\SqlObjectPool.cs">
<Link>Microsoft\Data\SqlClient\SqlObjectPool.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ArrayBufferWriter.cs">
<Link>Microsoft\Data\SqlClient\ArrayBufferWriter.cs</Link>
</Compile>
<Compile Include="$(CommonSourceRoot)Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationTimeoutRetryHelper.cs">
<Link>Microsoft\Data\SqlClient\ActiveDirectoryAuthenticationTimeoutRetryHelper.cs</Link>
</Compile>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ internal static extern unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] byte* pIn,
uint cbIn,
[In, Out] byte[] pOut,
[In, Out] byte* pOut,
[In] ref uint pcbOut,
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
byte* szServerInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ internal static extern unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] byte* pIn,
uint cbIn,
[In, Out] byte[] pOut,
[In, Out] byte* pOut,
[In] ref uint pcbOut,
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
byte* szServerInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ internal static extern unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] byte* pIn,
uint cbIn,
[In, Out] byte[] pOut,
[In, Out] byte* pOut,
[In] ref uint pcbOut,
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
byte* szServerInfo,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,7 @@ private static unsafe void SNIPacketSetData(SNIPacket pPacket, [In] byte* pbBuf,
private static unsafe uint SNISecGenClientContextWrapper(
[In] SNIHandle pConn,
[In, Out] ReadOnlySpan<byte> pIn,
[In, Out] byte[] pOut,
[In, Out] Span<byte> pOut,
[In] ref uint pcbOut,
[MarshalAsAttribute(UnmanagedType.Bool)] out bool pfDone,
byte* szServerInfo,
Expand All @@ -899,15 +899,16 @@ private static unsafe uint SNISecGenClientContextWrapper(
[MarshalAsAttribute(UnmanagedType.LPWStr)] string pwszPassword)
{
fixed (byte* pInPtr = pIn)
fixed (byte* pOutPtr = pOut)
{
switch (s_architecture)
{
case System.Runtime.InteropServices.Architecture.Arm64:
return SNINativeManagedWrapperARM64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
return SNINativeManagedWrapperARM64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOutPtr, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
case System.Runtime.InteropServices.Architecture.X64:
return SNINativeManagedWrapperX64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
return SNINativeManagedWrapperX64.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOutPtr, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
case System.Runtime.InteropServices.Architecture.X86:
return SNINativeManagedWrapperX86.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOut, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
return SNINativeManagedWrapperX86.SNISecGenClientContextWrapper(pConn, pInPtr, (uint)pIn.Length, pOutPtr, ref pcbOut, out pfDone, szServerInfo, cbServerInfo, pwszUserName, pwszPassword);
default:
throw ADP.SNIPlatformNotSupported(s_architecture.ToString());
}
Expand Down Expand Up @@ -1378,14 +1379,16 @@ Int32[] passwordOffsets // Offset into data buffer where the password to be w
}
}

internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, byte[] OutBuff, ref uint sendLength, byte[] serverUserName)
internal static unsafe uint SNISecGenClientContext(SNIHandle pConnectionObject, ReadOnlySpan<byte> inBuff, Span<byte> outBuff, out uint sendLength, byte[] serverUserName)
{
sendLength = (uint)outBuff.Length;

fixed (byte* pin_serverUserName = &serverUserName[0])
{
return SNISecGenClientContextWrapper(
pConnectionObject,
inBuff,
OutBuff,
outBuff,
ref sendLength,
out bool _,
pin_serverUserName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9266,8 +9266,7 @@ private void WriteLoginData(SqlLogin rec,
int length,
int featureExOffset,
string clientInterfaceName,
byte[] outSSPIBuff,
uint outSSPILength)
ReadOnlySpan<byte> outSSPI)
{
try
{
Expand Down Expand Up @@ -9440,8 +9439,8 @@ private void WriteLoginData(SqlLogin rec,
WriteShort(offset, _physicalStateObj); // ibSSPI offset
if (rec.useSSPI)
{
WriteShort((int)outSSPILength, _physicalStateObj);
offset += (int)outSSPILength;
WriteShort(outSSPI.Length, _physicalStateObj);
offset += outSSPI.Length;
}
else
{
Expand Down Expand Up @@ -9500,7 +9499,7 @@ private void WriteLoginData(SqlLogin rec,

// send over SSPI data if we are using SSPI
if (rec.useSSPI)
_physicalStateObj.WriteByteArray(outSSPIBuff, (int)outSSPILength, 0);
_physicalStateObj.WriteByteSpan(outSSPI);

WriteString(rec.attachDBFilename, _physicalStateObj);
if (!rec.useSSPI && !(_connHandler._federatedAuthenticationInfoRequested || _connHandler._federatedAuthenticationRequested))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -930,9 +930,28 @@ internal void WriteByte(byte b)
// set byte in buffer and increment the counter for number of bytes used in the out buffer
_outBuff[_outBytesUsed++] = b;
}
internal Task WriteByteSpan(ReadOnlySpan<byte> span, bool canAccumulate = true, TaskCompletionSource<object> completion = null)
{
return WriteBytes(span, span.Length, 0, canAccumulate, completion);
}

internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumulate = true, TaskCompletionSource<object> completion = null)
{
return WriteBytes(ReadOnlySpan<byte>.Empty, len, offsetBuffer, canAccumulate, completion, b);
}

//
// Takes a span or a byte array and writes it to the buffer
// If you pass in a span and a null array then the span wil be used.
// If you pass in a non-null array then the array will be used and the span is ignored.
// if the span cannot be written into the current packet then the remaining contents of the span are copied to a
// new heap allocated array that will used to callback into the method to continue the write operation.
private Task WriteBytes(ReadOnlySpan<byte> b, int len, int offsetBuffer, bool canAccumulate = true, TaskCompletionSource<object> completion = null, byte[] array = null)
{
if (array != null)
{
b = new ReadOnlySpan<byte>(array, offsetBuffer, len);
}
try
{
TdsParser.ReliabilitySection.Assert("unreliable call to WriteByteArray"); // you need to setup for a thread abort somewhere before you call this method
Expand All @@ -949,7 +968,7 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu

int offset = offsetBuffer;

Debug.Assert(b.Length >= len, "Invalid length sent to WriteByteArray()!");
Debug.Assert(b.Length >= len, "Invalid length sent to WriteBytes()!");

// loop through and write the entire array
do
Expand All @@ -963,12 +982,17 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu
int remainder = _outBuff.Length - _outBytesUsed;

// write the remainder
Buffer.BlockCopy(b, offset, _outBuff, _outBytesUsed, remainder);
Span<byte> copyTo = _outBuff.AsSpan(_outBytesUsed, remainder);
ReadOnlySpan<byte> copyFrom = b.Slice(0, remainder);

Debug.Assert(copyTo.Length == copyFrom.Length, $"copyTo.Length:{copyTo.Length} and copyFrom.Length{copyFrom.Length:D} should be the same");

copyFrom.CopyTo(copyTo);

// handle counters
offset += remainder;
_outBytesUsed += remainder;
len -= remainder;
b = b.Slice(remainder, len);

Task packetTask = WritePacket(TdsEnums.SOFTFLUSH, canAccumulate);

Expand All @@ -981,18 +1005,35 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu
completion = new TaskCompletionSource<object>();
task = completion.Task; // we only care about return from topmost call, so do not access Task property in other cases
}
WriteByteArraySetupContinuation(b, len, completion, offset, packetTask);

if (array == null)
{
byte[] tempArray = new byte[len];
Span<byte> copyTempTo = tempArray.AsSpan();

Debug.Assert(copyTempTo.Length == b.Length, $"copyTempTo.Length:{copyTempTo.Length} and copyTempFrom.Length:{b.Length:D} should be the same");

b.CopyTo(copyTempTo);
array = tempArray;
offset = 0;
}

WriteBytesSetupContinuation(array, len, completion, offset, packetTask);
return task;
}

}
else
{
//((stateObj._outBytesUsed + len) <= stateObj._outBuff.Length )
// Else the remainder of the string will fit into the buffer, so copy it into the
// buffer and then break out of the loop.

Buffer.BlockCopy(b, offset, _outBuff, _outBytesUsed, len);
Span<byte> copyTo = _outBuff.AsSpan(_outBytesUsed, len);
ReadOnlySpan<byte> copyFrom = b.Slice(0, len);

Debug.Assert(copyTo.Length == copyFrom.Length, $"copyTo.Length:{copyTo.Length} and copyFrom.Length:{copyFrom.Length:D} should be the same");

copyFrom.CopyTo(copyTo);

// handle out buffer bytes used counter
_outBytesUsed += len;
Expand Down Expand Up @@ -1021,7 +1062,7 @@ internal Task WriteByteArray(byte[] b, int len, int offsetBuffer, bool canAccumu
}

// This is in its own method to avoid always allocating the lambda in WriteByteArray
private void WriteByteArraySetupContinuation(byte[] b, int len, TaskCompletionSource<object> completion, int offset, Task packetTask)
private void WriteBytesSetupContinuation(byte[] b, int len, TaskCompletionSource<object> completion, int offset, Task packetTask)
{
AsyncHelper.ContinueTask(packetTask, completion,
() => WriteByteArray(b, len: len, offsetBuffer: offset, canAccumulate: false, completion: completion),
Expand Down
Loading
Loading