diff --git a/crypto/src/BouncyCastle.Crypto.csproj b/crypto/src/BouncyCastle.Crypto.csproj index 0b40b210c..a388099d4 100644 --- a/crypto/src/BouncyCastle.Crypto.csproj +++ b/crypto/src/BouncyCastle.Crypto.csproj @@ -2,6 +2,7 @@ net6.0;netstandard2.0;net461 + 12 Org.BouncyCastle ..\..\BouncyCastle.NET.snk true diff --git a/crypto/src/crypto/engines/AesEngine_X86.cs b/crypto/src/crypto/engines/AesEngine_X86.cs index 274fe327d..42f2bb145 100644 --- a/crypto/src/crypto/engines/AesEngine_X86.cs +++ b/crypto/src/crypto/engines/AesEngine_X86.cs @@ -18,17 +18,16 @@ public struct AesEngine_X86 { public static bool IsSupported => Org.BouncyCastle.Runtime.Intrinsics.X86.Aes.IsEnabled; - private static Vector128[] CreateRoundKeys(ReadOnlySpan key, bool forEncryption) + private static void CreateRoundKeys(ReadOnlySpan key, bool forEncryption, Span> K, out int length) { - Vector128[] K; - switch (key.Length) { case 16: { ReadOnlySpan rcon = stackalloc byte[]{ 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36 }; - K = new Vector128[11]; + length = 11; + K = K[..length]; var s = Load128(key[..16]); K[0] = s; @@ -47,7 +46,8 @@ private static Vector128[] CreateRoundKeys(ReadOnlySpan key, bool fo } case 24: { - K = new Vector128[13]; + length = 13; + K = K[..length]; var s1 = Load128(key[..16]); var s2 = Load64(key[16..24]).ToVector128(); @@ -93,7 +93,8 @@ private static Vector128[] CreateRoundKeys(ReadOnlySpan key, bool fo } case 32: { - K = new Vector128[15]; + length = 15; + K = K[..length]; var s1 = Load128(key[..16]); var s2 = Load128(key[16..32]); @@ -134,15 +135,19 @@ private static Vector128[] CreateRoundKeys(ReadOnlySpan key, bool fo K[i] = Aes.InverseMixColumns(K[i]); } - Array.Reverse(K); + K.Reverse(); } - - return K; } private enum Mode { DEC_128, DEC_192, DEC_256, ENC_128, ENC_192, ENC_256, UNINITIALIZED }; - private Vector128[] m_roundKeys = null; + struct Keys + { + public Vector128 k0, k1, k2, k3, k4, k5, k6, k7, k8, k9, k10, k11, k12, k13, k14; + } + private Keys keys; + private int keysLength = 15; + private Span> m_roundKeys => MemoryMarshal.Cast>(MemoryMarshal.CreateSpan(ref keys, 1))[..keysLength]; private Mode m_mode = Mode.UNINITIALIZED; public AesEngine_X86() @@ -163,7 +168,9 @@ public void Init(bool forEncryption, ICipherParameters parameters) throw new ArgumentException("invalid type: " + Platform.GetTypeName(parameters), nameof(parameters)); } - m_roundKeys = CreateRoundKeys(keyParameter.Key, forEncryption); + keysLength = 15; + m_roundKeys.Fill(default); + CreateRoundKeys(keyParameter.Key, forEncryption, m_roundKeys, out keysLength); if (m_roundKeys.Length == 11) { @@ -250,7 +257,7 @@ private void ImplRounds( } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void Decrypt128(Vector128[] roundKeys, ref Vector128 state) + private static void Decrypt128(ReadOnlySpan> roundKeys, ref Vector128 state) { var bounds = roundKeys[10]; var value = Sse2.Xor(state, roundKeys[0]); @@ -267,7 +274,7 @@ private static void Decrypt128(Vector128[] roundKeys, ref Vector128 } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void Decrypt192(Vector128[] roundKeys, ref Vector128 state) + private static void Decrypt192(ReadOnlySpan> roundKeys, ref Vector128 state) { var bounds = roundKeys[12]; var value = Sse2.Xor(state, roundKeys[0]); @@ -286,7 +293,7 @@ private static void Decrypt192(Vector128[] roundKeys, ref Vector128 } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void Decrypt256(Vector128[] roundKeys, ref Vector128 state) + private static void Decrypt256(ReadOnlySpan> roundKeys, ref Vector128 state) { var bounds = roundKeys[14]; var value = Sse2.Xor(state, roundKeys[0]); @@ -307,7 +314,7 @@ private static void Decrypt256(Vector128[] roundKeys, ref Vector128 } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void DecryptFour128(Vector128[] rk, + private static void DecryptFour128(ReadOnlySpan> rk, ref Vector128 s1, ref Vector128 s2, ref Vector128 s3, ref Vector128 s4) { var bounds = rk[10]; @@ -369,7 +376,7 @@ private static void DecryptFour128(Vector128[] rk, } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void DecryptFour192(Vector128[] rk, + private static void DecryptFour192(ReadOnlySpan> rk, ref Vector128 s1, ref Vector128 s2, ref Vector128 s3, ref Vector128 s4) { var bounds = rk[12]; @@ -441,7 +448,7 @@ private static void DecryptFour192(Vector128[] rk, } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void DecryptFour256(Vector128[] rk, + private static void DecryptFour256(ReadOnlySpan> rk, ref Vector128 s1, ref Vector128 s2, ref Vector128 s3, ref Vector128 s4) { var bounds = rk[14]; @@ -523,7 +530,7 @@ private static void DecryptFour256(Vector128[] rk, } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void Encrypt128(Vector128[] roundKeys, ref Vector128 state) + private static void Encrypt128(ReadOnlySpan> roundKeys, ref Vector128 state) { var bounds = roundKeys[10]; var value = Sse2.Xor(state, roundKeys[0]); @@ -540,7 +547,7 @@ private static void Encrypt128(Vector128[] roundKeys, ref Vector128 } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void Encrypt192(Vector128[] roundKeys, ref Vector128 state) + private static void Encrypt192(ReadOnlySpan> roundKeys, ref Vector128 state) { var bounds = roundKeys[12]; var value = Sse2.Xor(state, roundKeys[0]); @@ -559,7 +566,7 @@ private static void Encrypt192(Vector128[] roundKeys, ref Vector128 } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void Encrypt256(Vector128[] roundKeys, ref Vector128 state) + private static void Encrypt256(ReadOnlySpan> roundKeys, ref Vector128 state) { var bounds = roundKeys[14]; var value = Sse2.Xor(state, roundKeys[0]); @@ -580,7 +587,7 @@ private static void Encrypt256(Vector128[] roundKeys, ref Vector128 } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void EncryptFour128(Vector128[] rk, + private static void EncryptFour128(ReadOnlySpan> rk, ref Vector128 s1, ref Vector128 s2, ref Vector128 s3, ref Vector128 s4) { var bounds = rk[10]; @@ -642,7 +649,7 @@ private static void EncryptFour128(Vector128[] rk, } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void EncryptFour192(Vector128[] rk, + private static void EncryptFour192(ReadOnlySpan> rk, ref Vector128 s1, ref Vector128 s2, ref Vector128 s3, ref Vector128 s4) { var bounds = rk[12]; @@ -714,7 +721,7 @@ private static void EncryptFour192(Vector128[] rk, } [MethodImpl(MethodImplOptions.AggressiveInlining)] - private static void EncryptFour256(Vector128[] rk, + private static void EncryptFour256(ReadOnlySpan> rk, ref Vector128 s1, ref Vector128 s2, ref Vector128 s3, ref Vector128 s4) { var bounds = rk[14]; diff --git a/crypto/src/crypto/modes/GCMBlockCipher.cs b/crypto/src/crypto/modes/GCMBlockCipher.cs index b413088ab..2e388ae93 100644 --- a/crypto/src/crypto/modes/GCMBlockCipher.cs +++ b/crypto/src/crypto/modes/GCMBlockCipher.cs @@ -1,5 +1,6 @@ using System; #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER +using System.Buffers; using System.Runtime.CompilerServices; #endif #if NETCOREAPP3_0_OR_GREATER @@ -988,7 +989,7 @@ public int DoFinal(Span output) long c = (long)(((totalLength * 8) + 127) >> 7); // Calculate the adjustment factor - byte[] H_c = new byte[16]; + Span H_c = stackalloc byte[16]; if (exp == null) { exp = new BasicGcmExponentiator(); @@ -1047,18 +1048,31 @@ public void Reset() Reset(true); } + static void Reset(ref T[] array, int size) + { + if (array is null || array.Length != size) + { + array = new T[size]; + } + else + { + Arrays.Fill(array, default); + } + } + private void Reset(bool clearMac) { // note: we do not reset the nonce. - S = new byte[BlockSize]; - S_at = new byte[BlockSize]; - S_atPre = new byte[BlockSize]; - atBlock = new byte[BlockSize]; + Reset(ref S, BlockSize); + Reset(ref S_at, BlockSize); + Reset(ref S_atPre, BlockSize); + Reset(ref atBlock, BlockSize); atBlockPos = 0; atLength = 0; atLengthPre = 0; - counter = Arrays.Clone(J0); + Reset(ref counter, BlockSize); + J0.CopyTo(counter, 0); counter32 = Pack.BE_To_UInt32(counter, 12); blocksRemaining = uint.MaxValue - 1; bufOff = 0; diff --git a/crypto/src/crypto/modes/gcm/BasicGcmExponentiator.cs b/crypto/src/crypto/modes/gcm/BasicGcmExponentiator.cs index b24550c8c..07d3e6a5c 100644 --- a/crypto/src/crypto/modes/gcm/BasicGcmExponentiator.cs +++ b/crypto/src/crypto/modes/gcm/BasicGcmExponentiator.cs @@ -12,8 +12,15 @@ public void Init(byte[] x) { GcmUtilities.AsFieldElement(x, out this.x); } - +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + public void ExponentiateX(long pow, byte[] output) + { + ExponentiateX(pow, output.AsSpan()); + } + public void ExponentiateX(long pow, Span output) +#else public void ExponentiateX(long pow, byte[] output) +#endif { GcmUtilities.FieldElement y; GcmUtilities.One(out y); diff --git a/crypto/src/crypto/modes/gcm/GcmUtilities.cs b/crypto/src/crypto/modes/gcm/GcmUtilities.cs index ef6ae62c3..f31963136 100644 --- a/crypto/src/crypto/modes/gcm/GcmUtilities.cs +++ b/crypto/src/crypto/modes/gcm/GcmUtilities.cs @@ -39,6 +39,15 @@ internal static void AsBytes(ulong x0, ulong x1, byte[] z) Pack.UInt64_To_BE(x1, z, 8); } +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static void AsBytes(ulong x0, ulong x1, Span z) + { + Pack.UInt64_To_BE(x0, z, 0); + Pack.UInt64_To_BE(x1, z, 8); + } +#endif + #if NETSTANDARD1_0_OR_GREATER || NETCOREAPP1_0_OR_GREATER [MethodImpl(MethodImplOptions.AggressiveInlining)] #endif @@ -47,6 +56,14 @@ internal static void AsBytes(ref FieldElement x, byte[] z) AsBytes(x.n0, x.n1, z); } +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static void AsBytes(ref FieldElement x, Span z) + { + AsBytes(x.n0, x.n1, z); + } +#endif + #if NETSTANDARD1_0_OR_GREATER || NETCOREAPP1_0_OR_GREATER [MethodImpl(MethodImplOptions.AggressiveInlining)] #endif @@ -56,6 +73,15 @@ internal static void AsFieldElement(byte[] x, out FieldElement z) z.n1 = Pack.BE_To_UInt64(x, 8); } +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + [MethodImpl(MethodImplOptions.AggressiveInlining)] + internal static void AsFieldElement(ReadOnlySpan x, out FieldElement z) + { + z.n0 = Pack.BE_To_UInt64(x, 0); + z.n1 = Pack.BE_To_UInt64(x, 8); + } +#endif + internal static void DivideP(ref FieldElement x, out FieldElement z) { ulong x0 = x.n0, x1 = x.n1; @@ -72,6 +98,15 @@ internal static void Multiply(byte[] x, byte[] y) Multiply(ref X, ref Y); AsBytes(ref X, x); } +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + internal static void Multiply(Span x, ReadOnlySpan y) + { + AsFieldElement(x, out FieldElement X); + AsFieldElement(y, out FieldElement Y); + Multiply(ref X, ref Y); + AsBytes(ref X, x); + } +#endif internal static void Multiply(ref FieldElement x, ref FieldElement y) { diff --git a/crypto/src/crypto/modes/gcm/IGcmExponentiator.cs b/crypto/src/crypto/modes/gcm/IGcmExponentiator.cs index bd6c07363..22819f782 100644 --- a/crypto/src/crypto/modes/gcm/IGcmExponentiator.cs +++ b/crypto/src/crypto/modes/gcm/IGcmExponentiator.cs @@ -6,6 +6,9 @@ namespace Org.BouncyCastle.Crypto.Modes.Gcm public interface IGcmExponentiator { void Init(byte[] x); - void ExponentiateX(long pow, byte[] output); - } + void ExponentiateX(long pow, byte[] output); +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + void ExponentiateX(long pow, Span output); +#endif + } } diff --git a/crypto/src/crypto/modes/gcm/Tables1kGcmExponentiator.cs b/crypto/src/crypto/modes/gcm/Tables1kGcmExponentiator.cs index 417e0b636..cdddcc8fc 100644 --- a/crypto/src/crypto/modes/gcm/Tables1kGcmExponentiator.cs +++ b/crypto/src/crypto/modes/gcm/Tables1kGcmExponentiator.cs @@ -22,7 +22,15 @@ public void Init(byte[] x) lookupPowX2.Add(y); } +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER public void ExponentiateX(long pow, byte[] output) + { + ExponentiateX(pow, output.AsSpan()); + } + public void ExponentiateX(long pow, Span output) +#else + public void ExponentiateX(long pow, byte[] output) +#endif { GcmUtilities.FieldElement y; GcmUtilities.One(out y); diff --git a/crypto/src/tls/DtlsRecordLayer.cs b/crypto/src/tls/DtlsRecordLayer.cs index fe3b58d41..e96d3a324 100644 --- a/crypto/src/tls/DtlsRecordLayer.cs +++ b/crypto/src/tls/DtlsRecordLayer.cs @@ -1,4 +1,7 @@ using System; +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER +using System.Buffers; +#endif using System.Diagnostics; using System.IO; using System.Net.Sockets; @@ -388,59 +391,70 @@ internal int Receive(Span buffer, int waitMillis, DtlsRecordCallback recor Timeout timeout = Timeout.ForWaitMillis(waitMillis, currentTimeMillis); byte[] record = null; - while (waitMillis >= 0) + try { - if (null != m_retransmitTimeout && m_retransmitTimeout.RemainingMillis(currentTimeMillis) < 1) + while (waitMillis >= 0) { - m_retransmit = null; - m_retransmitEpoch = null; - m_retransmitTimeout = null; - } + if (null != m_retransmitTimeout && m_retransmitTimeout.RemainingMillis(currentTimeMillis) < 1) + { + m_retransmit = null; + m_retransmitEpoch = null; + m_retransmitTimeout = null; + } - if (Timeout.HasExpired(m_heartbeatTimeout, currentTimeMillis)) - { - if (null != m_heartbeatInFlight) - throw new TlsTimeoutException("Heartbeat timed out"); + if (Timeout.HasExpired(m_heartbeatTimeout, currentTimeMillis)) + { + if (null != m_heartbeatInFlight) + throw new TlsTimeoutException("Heartbeat timed out"); - this.m_heartbeatInFlight = HeartbeatMessage.Create(m_context, - HeartbeatMessageType.heartbeat_request, m_heartbeat.GeneratePayload()); - this.m_heartbeatTimeout = new Timeout(m_heartbeat.TimeoutMillis, currentTimeMillis); + this.m_heartbeatInFlight = HeartbeatMessage.Create(m_context, + HeartbeatMessageType.heartbeat_request, m_heartbeat.GeneratePayload()); + this.m_heartbeatTimeout = new Timeout(m_heartbeat.TimeoutMillis, currentTimeMillis); - this.m_heartbeatResendMillis = TlsUtilities.GetHandshakeResendTimeMillis(m_peer); - this.m_heartbeatResendTimeout = new Timeout(m_heartbeatResendMillis, currentTimeMillis); + this.m_heartbeatResendMillis = TlsUtilities.GetHandshakeResendTimeMillis(m_peer); + this.m_heartbeatResendTimeout = new Timeout(m_heartbeatResendMillis, currentTimeMillis); - SendHeartbeatMessage(m_heartbeatInFlight); - } - else if (Timeout.HasExpired(m_heartbeatResendTimeout, currentTimeMillis)) - { - this.m_heartbeatResendMillis = DtlsReliableHandshake.BackOff(m_heartbeatResendMillis); - this.m_heartbeatResendTimeout = new Timeout(m_heartbeatResendMillis, currentTimeMillis); + SendHeartbeatMessage(m_heartbeatInFlight); + } + else if (Timeout.HasExpired(m_heartbeatResendTimeout, currentTimeMillis)) + { + this.m_heartbeatResendMillis = DtlsReliableHandshake.BackOff(m_heartbeatResendMillis); + this.m_heartbeatResendTimeout = new Timeout(m_heartbeatResendMillis, currentTimeMillis); - SendHeartbeatMessage(m_heartbeatInFlight); - } + SendHeartbeatMessage(m_heartbeatInFlight); + } - waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_heartbeatTimeout, currentTimeMillis); - waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_heartbeatResendTimeout, currentTimeMillis); + waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_heartbeatTimeout, currentTimeMillis); + waitMillis = Timeout.ConstrainWaitMillis(waitMillis, m_heartbeatResendTimeout, currentTimeMillis); - // NOTE: Guard against bad logic giving a negative value - if (waitMillis < 0) - { - waitMillis = 1; - } + // NOTE: Guard against bad logic giving a negative value + if (waitMillis < 0) + { + waitMillis = 1; + } - int receiveLimit = m_transport.GetReceiveLimit(); - if (null == record || record.Length < receiveLimit) - { - record = new byte[receiveLimit]; - } + int receiveLimit = m_transport.GetReceiveLimit(); + if (null == record || record.Length < receiveLimit) + { + // record = new byte[receiveLimit]; + if (record is not null) + ArrayPool.Shared.Return(record); + record = ArrayPool.Shared.Rent(receiveLimit); + } - int received = ReceiveRecord(record, 0, receiveLimit, waitMillis); - int processed = ProcessRecord(received, record, buffer, recordCallback); - if (processed >= 0) - return processed; + int received = ReceiveRecord(record, 0, receiveLimit, waitMillis); + int processed = ProcessRecord(received, record, buffer, recordCallback); + if (processed >= 0) + return processed; - currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); - waitMillis = Timeout.GetWaitMillis(timeout, currentTimeMillis); + currentTimeMillis = DateTimeUtilities.CurrentUnixMs(); + waitMillis = Timeout.GetWaitMillis(timeout, currentTimeMillis); + } + } + finally + { + if (record is not null) + ArrayPool.Shared.Return(record); } return -1; @@ -1135,8 +1149,12 @@ private void SendRecord(short contentType, byte[] buf, int off, int len) int recordHeaderLength = m_writeEpoch.RecordHeaderLengthWrite; #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER - TlsEncodeResult encoded = m_writeEpoch.Cipher.EncodePlaintext(macSequenceNumber, contentType, - recordVersion, recordHeaderLength, buffer); + using var encoded = m_writeEpoch.Cipher.EncodePlaintext(macSequenceNumber, contentType, + recordVersion, recordHeaderLength, buffer +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + , ArrayPool.Shared +#endif + ); #else TlsEncodeResult encoded = m_writeEpoch.Cipher.EncodePlaintext(macSequenceNumber, contentType, recordVersion, recordHeaderLength, buf, off, len); diff --git a/crypto/src/tls/RecordStream.cs b/crypto/src/tls/RecordStream.cs index 739ee8d70..12903a145 100644 --- a/crypto/src/tls/RecordStream.cs +++ b/crypto/src/tls/RecordStream.cs @@ -293,7 +293,7 @@ internal void WriteRecord(short contentType, byte[] plaintext, int plaintextOffs long seqNo = m_writeSeqNo.NextValue(AlertDescription.internal_error); ProtocolVersion recordVersion = m_writeVersion; - TlsEncodeResult encoded = m_writeCipher.EncodePlaintext(seqNo, contentType, recordVersion, + using var encoded = m_writeCipher.EncodePlaintext(seqNo, contentType, recordVersion, RecordFormat.FragmentOffset, plaintext, plaintextOffset, plaintextLength); int ciphertextLength = encoded.len - RecordFormat.FragmentOffset; @@ -340,8 +340,12 @@ internal void WriteRecord(short contentType, ReadOnlySpan plaintext) long seqNo = m_writeSeqNo.NextValue(AlertDescription.internal_error); ProtocolVersion recordVersion = m_writeVersion; - TlsEncodeResult encoded = m_writeCipher.EncodePlaintext(seqNo, contentType, recordVersion, - RecordFormat.FragmentOffset, plaintext); + using var encoded = m_writeCipher.EncodePlaintext(seqNo, contentType, recordVersion, + RecordFormat.FragmentOffset, plaintext +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + , System.Buffers.ArrayPool.Shared +#endif + ); int ciphertextLength = encoded.len - RecordFormat.FragmentOffset; TlsUtilities.CheckUint16(ciphertextLength); diff --git a/crypto/src/tls/crypto/TlsCipher.cs b/crypto/src/tls/crypto/TlsCipher.cs index 5665cbb7e..45eafacfb 100644 --- a/crypto/src/tls/crypto/TlsCipher.cs +++ b/crypto/src/tls/crypto/TlsCipher.cs @@ -1,4 +1,7 @@ using System; +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER +using System.Buffers; +#endif using System.IO; namespace Org.BouncyCastle.Tls.Crypto @@ -44,7 +47,7 @@ TlsEncodeResult EncodePlaintext(long seqNo, short contentType, ProtocolVersion r #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER // TODO[api] Add a parameter for how much (D)TLSInnerPlaintext padding to add TlsEncodeResult EncodePlaintext(long seqNo, short contentType, ProtocolVersion recordVersion, - int headerAllocation, ReadOnlySpan plaintext); + int headerAllocation, ReadOnlySpan plaintext, ArrayPool pool = null); #endif /// Decode the passed in ciphertext using the current bulk cipher. diff --git a/crypto/src/tls/crypto/TlsEncodeResult.cs b/crypto/src/tls/crypto/TlsEncodeResult.cs index 963e4563a..8fe9c9cac 100644 --- a/crypto/src/tls/crypto/TlsEncodeResult.cs +++ b/crypto/src/tls/crypto/TlsEncodeResult.cs @@ -1,19 +1,45 @@ -using System; +#nullable enable + +using System; +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER +using System.Buffers; +#endif +using System.Threading; namespace Org.BouncyCastle.Tls.Crypto { - public sealed class TlsEncodeResult + public struct TlsEncodeResult: IDisposable { - public readonly byte[] buf; + public byte[] buf; public readonly int off, len; public readonly short recordType; +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + public readonly ArrayPool? pool; +#endif - public TlsEncodeResult(byte[] buf, int off, int len, short recordType) - { + public TlsEncodeResult(byte[] buf, int off, int len, short recordType +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + , ArrayPool? pool = null +#endif + ) { this.buf = buf; this.off = off; this.len = len; this.recordType = recordType; +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + this.pool = pool; +#endif + } + + public void Dispose() + { + byte[]? killBuf = Interlocked.Exchange(ref this.buf, null!); +#if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER + if (killBuf is not null) + { + this.pool?.Return(killBuf); + } +#endif } } } diff --git a/crypto/src/tls/crypto/TlsNullNullCipher.cs b/crypto/src/tls/crypto/TlsNullNullCipher.cs index 13fe092f7..b448e2a28 100644 --- a/crypto/src/tls/crypto/TlsNullNullCipher.cs +++ b/crypto/src/tls/crypto/TlsNullNullCipher.cs @@ -33,11 +33,13 @@ public TlsEncodeResult EncodePlaintext(long seqNo, short contentType, ProtocolVe #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER public TlsEncodeResult EncodePlaintext(long seqNo, short contentType, ProtocolVersion recordVersion, - int headerAllocation, ReadOnlySpan plaintext) + int headerAllocation, ReadOnlySpan plaintext, + System.Buffers.ArrayPool pool = null) { - byte[] result = new byte[headerAllocation + plaintext.Length]; + int bufferSize = headerAllocation + plaintext.Length; + byte[] result = pool?.Rent(bufferSize) ?? new byte[bufferSize]; plaintext.CopyTo(result.AsSpan(headerAllocation)); - return new TlsEncodeResult(result, 0, result.Length, contentType); + return new TlsEncodeResult(result, 0, bufferSize, contentType, pool); } #endif diff --git a/crypto/src/tls/crypto/impl/TlsAeadCipher.cs b/crypto/src/tls/crypto/impl/TlsAeadCipher.cs index 9fc9d3b9e..ea68595db 100644 --- a/crypto/src/tls/crypto/impl/TlsAeadCipher.cs +++ b/crypto/src/tls/crypto/impl/TlsAeadCipher.cs @@ -256,7 +256,8 @@ public virtual TlsEncodeResult EncodePlaintext(long seqNo, short contentType, Pr #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER public virtual TlsEncodeResult EncodePlaintext(long seqNo, short contentType, ProtocolVersion recordVersion, - int headerAllocation, ReadOnlySpan plaintext) + int headerAllocation, ReadOnlySpan plaintext, + System.Buffers.ArrayPool pool = null) { byte[] nonce = new byte[m_encryptNonce.Length + m_record_iv_length]; @@ -286,51 +287,60 @@ public virtual TlsEncodeResult EncodePlaintext(long seqNo, short contentType, Pr int encryptionLength = m_encryptCipher.GetOutputSize(innerPlaintextLength); int ciphertextLength = m_record_iv_length + encryptionLength; - byte[] output = new byte[headerAllocation + ciphertextLength]; + int bufferSize = headerAllocation + ciphertextLength; + byte[] output = pool?.Rent(bufferSize) ?? new byte[bufferSize]; int outputPos = headerAllocation; - - if (m_record_iv_length != 0) + short recordType; + try { - Array.Copy(nonce, nonce.Length - m_record_iv_length, output, outputPos, m_record_iv_length); - outputPos += m_record_iv_length; - } + if (m_record_iv_length != 0) + { + Array.Copy(nonce, nonce.Length - m_record_iv_length, output, outputPos, m_record_iv_length); + outputPos += m_record_iv_length; + } - short recordType = contentType; - if (m_encryptUseInnerPlaintext) - { - recordType = m_isTlsV13 ? ContentType.application_data : ContentType.tls12_cid; - } + recordType = contentType; + if (m_encryptUseInnerPlaintext) + { + recordType = m_isTlsV13 ? ContentType.application_data : ContentType.tls12_cid; + } - byte[] additionalData = GetAdditionalData(seqNo, recordType, recordVersion, ciphertextLength, - innerPlaintextLength, m_encryptConnectionID); + byte[] additionalData = GetAdditionalData(seqNo, recordType, recordVersion, ciphertextLength, + innerPlaintextLength, m_encryptConnectionID); - try - { - plaintext.CopyTo(output.AsSpan(outputPos)); - if (m_encryptUseInnerPlaintext) + try { - output[outputPos + plaintext.Length] = (byte)contentType; + plaintext.CopyTo(output.AsSpan(outputPos)); + if (m_encryptUseInnerPlaintext) + { + output[outputPos + plaintext.Length] = (byte)contentType; + } + + outputPos += m_encryptCipher.DoFinal(additionalData, output, outputPos, innerPlaintextLength, output, + outputPos); + } + catch (IOException) + { + throw; + } + catch (Exception e) + { + throw new TlsFatalAlert(AlertDescription.internal_error, e); } - outputPos += m_encryptCipher.DoFinal(additionalData, output, outputPos, innerPlaintextLength, output, - outputPos); + if (outputPos != bufferSize) + { + // NOTE: The additional data mechanism for AEAD ciphers requires exact output size prediction. + throw new TlsFatalAlert(AlertDescription.internal_error); + } } - catch (IOException) + catch { + pool?.Return(output); throw; } - catch (Exception e) - { - throw new TlsFatalAlert(AlertDescription.internal_error, e); - } - if (outputPos != output.Length) - { - // NOTE: The additional data mechanism for AEAD ciphers requires exact output size prediction. - throw new TlsFatalAlert(AlertDescription.internal_error); - } - - return new TlsEncodeResult(output, 0, output.Length, recordType); + return new TlsEncodeResult(output, 0, bufferSize, recordType, pool); } #endif diff --git a/crypto/src/tls/crypto/impl/TlsBlockCipher.cs b/crypto/src/tls/crypto/impl/TlsBlockCipher.cs index b9b8b2a76..826c1d31e 100644 --- a/crypto/src/tls/crypto/impl/TlsBlockCipher.cs +++ b/crypto/src/tls/crypto/impl/TlsBlockCipher.cs @@ -292,7 +292,8 @@ public virtual TlsEncodeResult EncodePlaintext(long seqNo, short contentType, Pr #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER public virtual TlsEncodeResult EncodePlaintext(long seqNo, short contentType, ProtocolVersion recordVersion, - int headerAllocation, ReadOnlySpan plaintext) + int headerAllocation, ReadOnlySpan plaintext, + System.Buffers.ArrayPool pool = null) { int blockSize = m_encryptCipher.GetBlockSize(); int macSize = m_writeMac.Size; @@ -321,57 +322,65 @@ public virtual TlsEncodeResult EncodePlaintext(long seqNo, short contentType, Pr totalSize += blockSize; } - byte[] outBuf = new byte[headerAllocation + totalSize]; + int bufferSize = headerAllocation + totalSize; int outOff = headerAllocation; + short recordType; + byte[] outBuf = pool?.Rent(bufferSize) ?? new byte[bufferSize]; + try { + if (m_useExplicitIV) + { + // Technically the explicit IV will be the encryption of this nonce + byte[] explicitIV = m_cryptoParams.NonceGenerator.GenerateNonce(blockSize); + Array.Copy(explicitIV, 0, outBuf, outOff, blockSize); + outOff += blockSize; + } - if (m_useExplicitIV) - { - // Technically the explicit IV will be the encryption of this nonce - byte[] explicitIV = m_cryptoParams.NonceGenerator.GenerateNonce(blockSize); - Array.Copy(explicitIV, 0, outBuf, outOff, blockSize); - outOff += blockSize; - } + int innerPlaintextOffset = outOff; - int innerPlaintextOffset = outOff; + plaintext.CopyTo(outBuf.AsSpan(outOff)); + outOff += plaintext.Length; - plaintext.CopyTo(outBuf.AsSpan(outOff)); - outOff += plaintext.Length; + recordType = contentType; + if (m_encryptUseInnerPlaintext) + { + outBuf[outOff++] = (byte)contentType; + recordType = ContentType.tls12_cid; + } - short recordType = contentType; - if (m_encryptUseInnerPlaintext) - { - outBuf[outOff++] = (byte)contentType; - recordType = ContentType.tls12_cid; - } + if (!m_encryptThenMac) + { + byte[] mac = m_writeMac.CalculateMac(seqNo, recordType, m_encryptConnectionID, + outBuf.AsSpan(innerPlaintextOffset, innerPlaintextLength)); + mac.CopyTo(outBuf.AsSpan(outOff)); + outOff += mac.Length; + } - if (!m_encryptThenMac) - { - byte[] mac = m_writeMac.CalculateMac(seqNo, recordType, m_encryptConnectionID, - outBuf.AsSpan(innerPlaintextOffset, innerPlaintextLength)); - mac.CopyTo(outBuf.AsSpan(outOff)); - outOff += mac.Length; - } + byte padByte = (byte)(padding_length - 1); + for (int i = 0; i < padding_length; ++i) + { + outBuf[outOff++] = padByte; + } - byte padByte = (byte)(padding_length - 1); - for (int i = 0; i < padding_length; ++i) - { - outBuf[outOff++] = padByte; - } + m_encryptCipher.DoFinal(outBuf, headerAllocation, outOff - headerAllocation, outBuf, headerAllocation); - m_encryptCipher.DoFinal(outBuf, headerAllocation, outOff - headerAllocation, outBuf, headerAllocation); + if (m_encryptThenMac) + { + byte[] mac = m_writeMac.CalculateMac(seqNo, recordType, m_encryptConnectionID, + outBuf.AsSpan(headerAllocation, outOff - headerAllocation)); + Array.Copy(mac, 0, outBuf, outOff, mac.Length); + outOff += mac.Length; + } - if (m_encryptThenMac) + if (outOff != bufferSize) + throw new TlsFatalAlert(AlertDescription.internal_error); + } + catch { - byte[] mac = m_writeMac.CalculateMac(seqNo, recordType, m_encryptConnectionID, - outBuf.AsSpan(headerAllocation, outOff - headerAllocation)); - Array.Copy(mac, 0, outBuf, outOff, mac.Length); - outOff += mac.Length; + pool?.Return(outBuf); + throw; } - if (outOff != outBuf.Length) - throw new TlsFatalAlert(AlertDescription.internal_error); - - return new TlsEncodeResult(outBuf, 0, outBuf.Length, recordType); + return new TlsEncodeResult(outBuf, 0, bufferSize, recordType, pool); } #endif diff --git a/crypto/src/tls/crypto/impl/TlsNullCipher.cs b/crypto/src/tls/crypto/impl/TlsNullCipher.cs index 7c1bad6f7..80c2bb493 100644 --- a/crypto/src/tls/crypto/impl/TlsNullCipher.cs +++ b/crypto/src/tls/crypto/impl/TlsNullCipher.cs @@ -137,28 +137,38 @@ public virtual TlsEncodeResult EncodePlaintext(long seqNo, short contentType, Pr #if NETCOREAPP2_1_OR_GREATER || NETSTANDARD2_1_OR_GREATER public virtual TlsEncodeResult EncodePlaintext(long seqNo, short contentType, ProtocolVersion recordVersion, - int headerAllocation, ReadOnlySpan plaintext) + int headerAllocation, ReadOnlySpan plaintext, + System.Buffers.ArrayPool pool = null) { int macSize = m_writeMac.Size; // TODO[cid] If we support adding padding to DTLSInnerPlaintext, this will need review int innerPlaintextLength = plaintext.Length + (m_encryptUseInnerPlaintext ? 1 : 0); - byte[] ciphertext = new byte[headerAllocation + innerPlaintextLength + macSize]; - plaintext.CopyTo(ciphertext.AsSpan(headerAllocation)); + int bufferSize = headerAllocation + innerPlaintextLength + macSize; + byte[] ciphertext = pool?.Rent(bufferSize) ?? new byte[bufferSize]; + short recordType; + try { + plaintext.CopyTo(ciphertext.AsSpan(headerAllocation)); - short recordType = contentType; - if (m_encryptUseInnerPlaintext) + recordType = contentType; + if (m_encryptUseInnerPlaintext) + { + ciphertext[headerAllocation + plaintext.Length] = (byte)contentType; + recordType = ContentType.tls12_cid; + } + + byte[] mac = m_writeMac.CalculateMac(seqNo, recordType, m_encryptConnectionID, + ciphertext.AsSpan(headerAllocation, innerPlaintextLength)); + mac.CopyTo(ciphertext.AsSpan(headerAllocation + innerPlaintextLength)); + } + catch { - ciphertext[headerAllocation + plaintext.Length] = (byte)contentType; - recordType = ContentType.tls12_cid; + pool?.Return(ciphertext); + throw; } - byte[] mac = m_writeMac.CalculateMac(seqNo, recordType, m_encryptConnectionID, - ciphertext.AsSpan(headerAllocation, innerPlaintextLength)); - mac.CopyTo(ciphertext.AsSpan(headerAllocation + innerPlaintextLength)); - - return new TlsEncodeResult(ciphertext, 0, ciphertext.Length, recordType); + return new TlsEncodeResult(ciphertext, 0, bufferSize, recordType, pool); } #endif