Skip to content

Commit

Permalink
Improve AddImpl and SubtractImpl with Avx512
Browse files Browse the repository at this point in the history
  • Loading branch information
benaadams committed Dec 25, 2024
1 parent 53d6e10 commit b4e1355
Showing 1 changed file with 59 additions and 45 deletions.
104 changes: 59 additions & 45 deletions src/Nethermind.Int256/UInt256.cs
Original file line number Diff line number Diff line change
Expand Up @@ -406,24 +406,32 @@ public static bool AddImpl(in UInt256 a, in UInt256 b, out UInt256 res)
{
if (Avx2.IsSupported)
{
var av = Unsafe.As<UInt256,Vector256<ulong>>(ref Unsafe.AsRef(in a));
var bv = Unsafe.As<UInt256,Vector256<ulong>>(ref Unsafe.AsRef(in b));
Vector256<ulong> av = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in a));
Vector256<ulong> bv = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in b));

var result = Avx2.Add(av, bv);

var carryFromBothHighBits = Avx2.And(av, bv);
var eitherHighBit = Avx2.Or(av, bv);
var highBitNotInResult = Avx2.AndNot(result, eitherHighBit);
Vector256<ulong> result = Avx2.Add(av, bv);
Vector256<ulong> vCarry;
if (Avx512F.VL.IsSupported)
{
vCarry = Avx512F.VL.CompareLessThan(result, av);
}
else
{
// Workaround for missing Vector256.CompareLessThan
Vector256<ulong> carryFromBothHighBits = Avx2.And(av, bv);
Vector256<ulong> eitherHighBit = Avx2.Or(av, bv);
Vector256<ulong> highBitNotInResult = Avx2.AndNot(result, eitherHighBit);

// Set high bits where carry occurs
var vCarry = Avx2.Or(carryFromBothHighBits, highBitNotInResult);
// Set high bits where carry occurs
vCarry = Avx2.Or(carryFromBothHighBits, highBitNotInResult);
}
// Move carry from Vector space to int
var carry = Avx.MoveMask(Unsafe.As<Vector256<ulong>, Vector256<double>>(ref vCarry));
int carry = Avx.MoveMask(vCarry.AsDouble());

// All bits set will cascade another carry when carry is added to it
var vCascade = Avx2.CompareEqual(result, Vector256<ulong>.AllBitsSet);
Vector256<ulong> vCascade = Avx2.CompareEqual(result, Vector256<ulong>.AllBitsSet);
// Move cascade from Vector space to int
var cascade = Avx.MoveMask(Unsafe.As<Vector256<ulong>, Vector256<double>>(ref vCascade));
int cascade = Avx.MoveMask(Unsafe.As<Vector256<ulong>, Vector256<double>>(ref vCascade));

// Use ints to work out the Vector cross lane cascades
// Move carry to next bit and add cascade
Expand All @@ -434,12 +442,12 @@ public static bool AddImpl(in UInt256 a, in UInt256 b, out UInt256 res)
cascade &= 0x0f;

// Lookup the carries to broadcast to the Vectors
var cascadedCarries = Unsafe.Add(ref Unsafe.As<byte, Vector256<ulong>>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade);
Vector256<ulong> cascadedCarries = Unsafe.Add(ref Unsafe.As<byte, Vector256<ulong>>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade);

// Mark res as initalized so we can use it as left said of ref assignment
// Mark res as initialized so we can use it as the left side of a ref assignment
Unsafe.SkipInit(out res);
// Add the cascadedCarries to the result
Unsafe.As<UInt256,Vector256<ulong>>(ref res) = Avx2.Add(result, cascadedCarries);
Unsafe.As<UInt256, Vector256<ulong>>(ref res) = Avx2.Add(result, cascadedCarries);

return (carry & 0b1_0000) != 0;
}
Expand All @@ -458,7 +466,6 @@ public static bool AddImpl(in UInt256 a, in UInt256 b, out UInt256 res)
// Debug.Assert((BigInteger)res == ((BigInteger)a + (BigInteger)b) % ((BigInteger)1 << 256));
// #endif
}

public void Add(in UInt256 a, out UInt256 res) => Add(this, a, out res);

/// <summary>
Expand Down Expand Up @@ -665,7 +672,7 @@ private static void Udivrem(ref ulong quot, ref ulong u, int length, in UInt256
int uLen = 0;
for (int i = length - 1; i >= 0; i--)
{
if (Unsafe.Add(ref u,i) != 0)
if (Unsafe.Add(ref u, i) != 0)
{
uLen = i + 1;
break;
Expand Down Expand Up @@ -730,13 +737,13 @@ private static void Udivrem(ref ulong quot, ref ulong u, int length, in UInt256
goto r3;
}

r3:
r3:
rem2 = Rsh(un[2], shift) | Lsh(un[3], 64 - shift);
r2:
r2:
rem1 = Rsh(un[1], shift) | Lsh(un[2], 64 - shift);
r1:
r1:
rem0 = Rsh(un[0], shift) | Lsh(un[1], 64 - shift);
r0:
r0:

rem = new UInt256(rem0, rem1, rem2, rem3);
}
Expand Down Expand Up @@ -879,25 +886,32 @@ private static bool SubtractImpl(in UInt256 a, in UInt256 b, out UInt256 res)
{
if (Avx2.IsSupported)
{
var av = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in a));
var bv = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in b));

var result = Avx2.Subtract(av, bv);
// Invert top bits as Avx2.CompareGreaterThan is only available for longs, not unsigned
var resultSigned = Avx2.Xor(result, Vector256.Create<ulong>(0x8000_0000_0000_0000));
var avSigned = Avx2.Xor(av, Vector256.Create<ulong>(0x8000_0000_0000_0000));
Vector256<ulong> av = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in a));
Vector256<ulong> bv = Unsafe.As<UInt256, Vector256<ulong>>(ref Unsafe.AsRef(in b));

// Which vectors need to borrow from the next
var vBorrow = Avx2.CompareGreaterThan(Unsafe.As<Vector256<ulong>, Vector256<long>>(ref resultSigned),
Unsafe.As<Vector256<ulong>, Vector256<long>>(ref avSigned));
Vector256<ulong> result = Avx2.Subtract(av, bv);
Vector256<ulong> vBorrow;
if (Avx512F.VL.IsSupported)
{
vBorrow = Avx512F.VL.CompareGreaterThan(result, av);
}
else
{
// Invert top bits as Avx2.CompareGreaterThan is only available for longs, not unsigned
Vector256<ulong> resultSigned = Avx2.Xor(result, Vector256.Create<ulong>(0x8000_0000_0000_0000));
Vector256<ulong> avSigned = Avx2.Xor(av, Vector256.Create<ulong>(0x8000_0000_0000_0000));

// Which vectors need to borrow from the next
vBorrow = Avx2.CompareGreaterThan(Unsafe.As<Vector256<ulong>, Vector256<long>>(ref resultSigned),
Unsafe.As<Vector256<ulong>, Vector256<long>>(ref avSigned)).AsUInt64();
}
// Move borrow from Vector space to int
var borrow = Avx.MoveMask(Unsafe.As<Vector256<long>, Vector256<double>>(ref vBorrow));
int borrow = Avx.MoveMask(vBorrow.AsDouble());

// All zeros will cascade another borrow when borrow is subtracted from it
var vCascade = Avx2.CompareEqual(result, Vector256<ulong>.Zero);
Vector256<ulong> vCascade = Avx2.CompareEqual(result, Vector256<ulong>.Zero);
// Move cascade from Vector space to int
var cascade = Avx.MoveMask(Unsafe.As<Vector256<ulong>, Vector256<double>>(ref vCascade));
int cascade = Avx.MoveMask(vCascade.AsDouble());

// Use ints to work out the Vector cross lane cascades
// Move borrow to next bit and add cascade
Expand All @@ -908,9 +922,9 @@ private static bool SubtractImpl(in UInt256 a, in UInt256 b, out UInt256 res)
cascade &= 0x0f;

// Lookup the borrows to broadcast to the Vectors
var cascadedBorrows = Unsafe.Add(ref Unsafe.As<byte, Vector256<ulong>>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade);
Vector256<ulong> cascadedBorrows = Unsafe.Add(ref Unsafe.As<byte, Vector256<ulong>>(ref MemoryMarshal.GetReference(s_broadcastLookup)), cascade);

// Mark res as initalized so we can use it as left said of ref assignment
// Mark res as initialized so we can use it as the left side of a ref assignment
Unsafe.SkipInit(out res);
// Subtract the cascadedBorrows from the result
Unsafe.As<UInt256, Vector256<ulong>>(ref res) = Avx2.Subtract(result, cascadedBorrows);
Expand Down Expand Up @@ -1315,15 +1329,15 @@ public static void Lsh(in UInt256 x, int n, out UInt256 res)
a = Rsh(res.u0, 64 - n);
z0 = Lsh(res.u0, n);

sh64:
sh64:
b = Rsh(res.u1, 64 - n);
z1 = Lsh(res.u1, n) | a;

sh128:
sh128:
a = Rsh(res.u2, 64 - n);
z2 = Lsh(res.u2, n) | b;

sh192:
sh192:
z3 = Lsh(res.u3, n) | a;

res = new UInt256(z0, z1, z2, z3);
Expand Down Expand Up @@ -1425,15 +1439,15 @@ public static void Rsh(in UInt256 x, int n, out UInt256 res)
a = Lsh(res.u3, 64 - n);
z3 = Rsh(res.u3, n);

sh64:
sh64:
b = Lsh(res.u2, 64 - n);
z2 = Rsh(res.u2, n) | a;

sh128:
sh128:
a = Lsh(res.u1, 64 - n);
z1 = Rsh(res.u1, n) | b;

sh192:
sh192:
z0 = Rsh(res.u0, n) | a;

res = new UInt256(z0, z1, z2, z3);
Expand Down Expand Up @@ -1923,13 +1937,13 @@ public static bool TryParse(in ReadOnlySpan<char> value, NumberStyles style, IFo
public TypeCode GetTypeCode() => TypeCode.Object;
public bool ToBoolean(IFormatProvider? provider) => !IsZero;
public byte ToByte(IFormatProvider? provider) => System.Convert.ToByte(ToDecimal(provider), provider);
public char ToChar(IFormatProvider? provider) => System.Convert.ToChar(ToDecimal(provider), provider);
public DateTime ToDateTime(IFormatProvider? provider) => System.Convert.ToDateTime(ToDecimal(provider), provider);
public char ToChar(IFormatProvider? provider) => System.Convert.ToChar(ToDecimal(provider), provider);
public DateTime ToDateTime(IFormatProvider? provider) => System.Convert.ToDateTime(ToDecimal(provider), provider);
public decimal ToDecimal(IFormatProvider? provider) => (decimal)this;
public double ToDouble(IFormatProvider? provider) => (double)this;
public short ToInt16(IFormatProvider? provider) => System.Convert.ToInt16(ToDecimal(provider), provider);
public int ToInt32(IFormatProvider? provider) => System.Convert.ToInt32(ToDecimal(provider), provider);
public long ToInt64(IFormatProvider? provider) => System.Convert.ToInt64(ToDecimal(provider), provider);
public long ToInt64(IFormatProvider? provider) => System.Convert.ToInt64(ToDecimal(provider), provider);
public sbyte ToSByte(IFormatProvider? provider) => System.Convert.ToSByte(ToDecimal(provider), provider);
public float ToSingle(IFormatProvider? provider) => System.Convert.ToSingle(ToDouble(provider), provider);
public string ToString(IFormatProvider? provider) => ((BigInteger)this).ToString(provider);
Expand Down

0 comments on commit b4e1355

Please sign in to comment.