Commit 76d1203

Improvements to the "Sum" SIMD algorithm (dotnet#1112)
1 parent 263a67b commit 76d1203

12 files changed: +297 -59 lines changed

src/Microsoft.ML.CpuMath/AvxIntrinsics.cs (+91 -20)

@@ -1300,41 +1300,112 @@ public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan
         }
     }

-    public static unsafe float SumU(ReadOnlySpan<float> src)
+    public static unsafe float Sum(ReadOnlySpan<float> src)
     {
-        fixed (float* psrc = &MemoryMarshal.GetReference(src))
+        fixed (float* pSrc = &MemoryMarshal.GetReference(src))
+        fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
+        fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
         {
-            float* pSrcEnd = psrc + src.Length;
-            float* pSrcCurrent = psrc;
+            float* pValues = pSrc;
+            int length = src.Length;

-            Vector256<float> result256 = Avx.SetZeroVector256<float>();
-
-            while (pSrcCurrent + 8 <= pSrcEnd)
+            if (length < 8)
             {
-                result256 = Avx.Add(result256, Avx.LoadVector256(pSrcCurrent));
-                pSrcCurrent += 8;
+                // Handle cases where we have less than 256-bits total and can't ever use SIMD acceleration.
+
+                float res = 0;
+
+                switch (length)
+                {
+                    case 7: res += pValues[6]; goto case 6;
+                    case 6: res += pValues[5]; goto case 5;
+                    case 5: res += pValues[4]; goto case 4;
+                    case 4: res += pValues[3]; goto case 3;
+                    case 3: res += pValues[2]; goto case 2;
+                    case 2: res += pValues[1]; goto case 1;
+                    case 1: res += pValues[0]; break;
+                }
+
+                return res;
             }

-            result256 = VectorSum256(in result256);
-            Vector128<float> resultPadded = Sse.AddScalar(Avx.GetLowerHalf(result256), GetHigh(result256));
+            Vector256<float> result = Avx.SetZeroVector256<float>();

-            Vector128<float> result128 = Sse.SetZeroVector128();
+            nuint address = (nuint)(pValues);
+            int misalignment = (int)(address % 32);
+            int remainder = 0;

-            if (pSrcCurrent + 4 <= pSrcEnd)
+            if ((misalignment & 3) != 0)
             {
-                result128 = Sse.Add(result128, Sse.LoadVector128(pSrcCurrent));
-                pSrcCurrent += 4;
+                // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations
+
+                remainder = length % 8;
+
+                for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 8)
+                {
+                    result = Avx.Add(result, Avx.LoadVector256(pValues));
+                }
             }
+            else
+            {
+                if (misalignment != 0)
+                {
+                    // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then
+                    // masking any elements that will be included in the first aligned read

-            result128 = SseIntrinsics.VectorSum128(in result128);
+                    misalignment >>= 2;
+                    misalignment = 8 - misalignment;

-            while (pSrcCurrent < pSrcEnd)
+                    Vector256<float> mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
+                    Vector256<float> temp = Avx.And(mask, Avx.LoadVector256(pValues));
+                    result = Avx.Add(result, temp);
+
+                    pValues += misalignment;
+                    length -= misalignment;
+                }
+
+                if (length > 7)
+                {
+                    // Handle all the 256-bit blocks that we can now that we have offset to an aligned address
+
+                    remainder = length % 8;
+
+                    for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 8)
+                    {
+                        // The JIT will only fold away unaligned loads due to the semantics behind
+                        // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since
+                        // modern hardware has unaligned loads that are as fast as aligned loads,
+                        // when it doesn't cross a cache-line/page boundary, we will just assert
+                        // that the alignment is correct and allow for the more-efficient codegen.
+
+                        Contracts.Assert(((nuint)(pValues) % 32) == 0);
+                        result = Avx.Add(result, Avx.LoadVector256(pValues));
+                    }
+                }
+                else
+                {
+                    // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not
+                    // 256-bit aligned. This means we can't do any aligned loads and will just end up doing two
+                    // unaligned loads where we mask the input each time.
+                    remainder = length;
+                }
+            }
+
+            if (remainder != 0)
             {
-                result128 = Sse.AddScalar(result128, Sse.LoadScalarVector128(pSrcCurrent));
-                pSrcCurrent++;
+                // Handle any trailing elements that don't fit into a 256-bit block by moving back so that the next
+                // unaligned load will read to the end of the array and then mask out any elements already processed
+
+                pValues -= (8 - remainder);
+
+                Vector256<float> mask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
+                Vector256<float> temp = Avx.And(mask, Avx.LoadVector256(pValues));
+                result = Avx.Add(result, temp);
             }

-            return Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded));
+            // Sum all the elements together and return the result
+            result = VectorSum256(in result);
+            return Sse.ConvertToSingle(Sse.AddScalar(Avx.GetLowerHalf(result), GetHigh(result)));
         }
     }
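Note: the new Sum splits the work into a leading masked load (to reach a 32-byte boundary), a main loop of aligned 256-bit loads, and a trailing masked load that re-reads the final block and zeroes the lanes already counted. The sketch below is a scalar model of that bookkeeping, written for this page only (a hypothetical SumModel helper, no intrinsics); the committed code does the same leading/aligned/trailing split with Vector256<float> loads ANDed against rows of LeadingAlignmentMask and TrailingAlignmentMask.

    // A minimal scalar model (no intrinsics) of the 8-wide control flow above. Hypothetical
    // helper, not part of ML.NET; lane-by-lane sums stand in for the masked vector loads.
    static unsafe float SumModel(float* pValues, int length)
    {
        const int W = 8;                               // floats per 256-bit block

        if (length < W)                                // too small to ever use SIMD
        {
            float small = 0;
            for (int i = 0; i < length; i++)
                small += pValues[i];
            return small;
        }

        float sum = 0;
        int misalignment = (int)((nuint)pValues % 32);
        int remainder;

        if ((misalignment & 3) != 0)
        {
            // Not even float-aligned: the real code keeps using unaligned 256-bit loads.
            remainder = length % W;
            for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += W)
                for (int i = 0; i < W; i++)
                    sum += pValues[i];
        }
        else
        {
            if (misalignment != 0)
            {
                // Leading step: consume just enough elements to reach the next 32-byte boundary.
                // The real code does one unaligned load and masks off the lanes past the boundary.
                misalignment = W - (misalignment >> 2);
                for (int i = 0; i < misalignment; i++)
                    sum += pValues[i];
                pValues += misalignment;
                length -= misalignment;
            }

            if (length >= W)
            {
                remainder = length % W;
                for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += W)
                    for (int i = 0; i < W; i++)        // aligned 256-bit loads in the real code
                        sum += pValues[i];
            }
            else
            {
                remainder = length;                    // the unaligned 8-16 element "worst case"
            }
        }

        if (remainder != 0)
        {
            // Trailing step: step back so one last read ends exactly at the array end and
            // count only the final `remainder` elements (the trailing mask zeroes the rest).
            pValues -= (W - remainder);
            for (int i = W - remainder; i < W; i++)
                sum += pValues[i];
        }

        return sum;
    }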

src/Microsoft.ML.CpuMath/CpuMathUtils.netcoreapp.cs (+2 -2)

@@ -388,11 +388,11 @@ public static float Sum(ReadOnlySpan<float> src)

         if (Avx.IsSupported)
         {
-            return AvxIntrinsics.SumU(src);
+            return AvxIntrinsics.Sum(src);
         }
         else if (Sse.IsSupported)
         {
-            return SseIntrinsics.SumU(src);
+            return SseIntrinsics.Sum(src);
         }
         else
         {
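The dispatch above follows the usual hardware-intrinsics pattern: probe the widest ISA first, then fall back to a portable loop. A self-contained sketch of that shape (SumAvx, SumSse and SumScalar are placeholders for this example, not the ML.NET implementations):

    using System;
    using System.Runtime.Intrinsics.X86;

    static class SumDispatchSketch
    {
        public static float Sum(ReadOnlySpan<float> src)
        {
            if (Avx.IsSupported)
            {
                return SumAvx(src);       // 256-bit path (AvxIntrinsics.Sum above)
            }
            else if (Sse.IsSupported)
            {
                return SumSse(src);       // 128-bit path (SseIntrinsics.Sum above)
            }
            else
            {
                return SumScalar(src);    // portable fallback
            }
        }

        // Placeholder bodies so the sketch compiles.
        private static float SumAvx(ReadOnlySpan<float> src) => SumScalar(src);
        private static float SumSse(ReadOnlySpan<float> src) => SumScalar(src);

        private static float SumScalar(ReadOnlySpan<float> src)
        {
            float sum = 0;
            foreach (float value in src)
                sum += value;
            return sum;
        }
    }

Avx.IsSupported and Sse.IsSupported are treated as constants by the JIT, so the untaken branches are dropped from the generated code.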

src/Microsoft.ML.CpuMath/Sse.cs (+1 -1)

@@ -246,7 +246,7 @@ public static float Sum(ReadOnlySpan<float> src)
         unsafe
         {
             fixed (float* psrc = &MemoryMarshal.GetReference(src))
-                return Thunk.SumU(psrc, src.Length);
+                return Thunk.Sum(psrc, src.Length);
         }
     }
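Thunk.Sum takes a raw pointer, so the managed wrapper pins the span via MemoryMarshal.GetReference, which works whether the span wraps an array, stackalloc memory, or native memory. A minimal sketch of that pinning pattern (NativeSum is a hypothetical export used only to show the call shape):

    using System;
    using System.Runtime.InteropServices;

    static class PinningSketch
    {
        // Hypothetical native export, used only to illustrate the call shape of Thunk.Sum.
        [DllImport("CpuMathNative")]
        private static extern unsafe float NativeSum(float* pValues, int length);

        public static unsafe float Sum(ReadOnlySpan<float> src)
        {
            // MemoryMarshal.GetReference yields a ref to the span's first element, so
            // `fixed` can pin spans backed by arrays, stackalloc, or native memory.
            fixed (float* pSrc = &MemoryMarshal.GetReference(src))
            {
                return NativeSum(pSrc, src.Length);
            }
        }
    }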

src/Microsoft.ML.CpuMath/SseIntrinsics.cs (+86 -11)

@@ -1140,29 +1140,104 @@ public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan
         }
     }

-    public static unsafe float SumU(ReadOnlySpan<float> src)
+    public static unsafe float Sum(ReadOnlySpan<float> src)
     {
-        fixed (float* psrc = &MemoryMarshal.GetReference(src))
+        fixed (float* pSrc = &MemoryMarshal.GetReference(src))
+        fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
+        fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
         {
-            float* pSrcEnd = psrc + src.Length;
-            float* pSrcCurrent = psrc;
+            float* pValues = pSrc;
+            int length = src.Length;
+
+            if (length < 4)
+            {
+                // Handle cases where we have less than 128-bits total and can't ever use SIMD acceleration.
+
+                float res = 0;
+
+                switch (length)
+                {
+                    case 3: res += pValues[2]; goto case 2;
+                    case 2: res += pValues[1]; goto case 1;
+                    case 1: res += pValues[0]; break;
+                }
+
+                return res;
+            }

             Vector128<float> result = Sse.SetZeroVector128();

-            while (pSrcCurrent + 4 <= pSrcEnd)
+            nuint address = (nuint)(pValues);
+            int misalignment = (int)(address % 16);
+            int remainder = 0;
+
+            if ((misalignment & 3) != 0)
             {
-                result = Sse.Add(result, Sse.LoadVector128(pSrcCurrent));
-                pSrcCurrent += 4;
+                // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations
+
+                remainder = length % 4;
+
+                for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 4)
+                {
+                    result = Sse.Add(result, Sse.LoadVector128(pValues));
+                }
             }
+            else
+            {
+                if (misalignment != 0)
+                {
+                    // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then
+                    // masking any elements that will be included in the first aligned read

-            result = VectorSum128(in result);
+                    misalignment >>= 2;
+                    misalignment = 4 - misalignment;

-            while (pSrcCurrent < pSrcEnd)
+                    Vector128<float> mask = Sse.LoadVector128(((float*)(pLeadingAlignmentMask)) + (misalignment * 4));
+                    Vector128<float> temp = Sse.And(mask, Sse.LoadVector128(pValues));
+                    result = Sse.Add(result, temp);
+
+                    pValues += misalignment;
+                    length -= misalignment;
+                }
+
+                if (length > 3)
+                {
+                    // Handle all the 128-bit blocks that we can now that we have offset to an aligned address
+
+                    remainder = length % 4;
+
+                    for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 4)
+                    {
+                        // If we aren't using the VEX-encoding, the JIT will only fold away aligned loads
+                        // (due to semantics of the legacy encoding).
+                        // We don't need an assert, since the instruction will throw for unaligned inputs.
+
+                        result = Sse.Add(result, Sse.LoadAlignedVector128(pValues));
+                    }
+                }
+                else
+                {
+                    // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not
+                    // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two
+                    // unaligned loads where we mask the input each time.
+                    remainder = length;
+                }
+            }
+
+            if (remainder != 0)
             {
-                result = Sse.AddScalar(result, Sse.LoadScalarVector128(pSrcCurrent));
-                pSrcCurrent++;
+                // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next
+                // unaligned load will read to the end of the array and then mask out any elements already processed
+
+                pValues -= (4 - remainder);
+
+                Vector128<float> mask = Sse.LoadVector128(((float*)(pTrailingAlignmentMask)) + (remainder * 4));
+                Vector128<float> temp = Sse.And(mask, Sse.LoadVector128(pValues));
+                result = Sse.Add(result, temp);
             }

+            // Sum all the elements together and return the result
+            result = VectorSum128(in result);
             return Sse.ConvertToSingle(result);
         }
     }
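The masks are indexed by element count: row i of the leading table keeps the first i lanes, row i of the trailing table keeps the last i lanes, and Sse.And zeroes everything else. One plausible 128-bit layout, shown only to make the (misalignment * 4) and (remainder * 4) indexing concrete (the actual tables in CpuMath may be declared differently):

    // Illustrative 4x4 mask tables for the 128-bit path. An all-ones lane (0xFFFFFFFF)
    // passes the corresponding float through Sse.And unchanged; a zero lane clears it.
    internal static class AlignmentMaskSketch
    {
        internal static readonly uint[] LeadingAlignmentMask = new uint[16]
        {
            0x00000000, 0x00000000, 0x00000000, 0x00000000, // keep 0 lanes
            0xFFFFFFFF, 0x00000000, 0x00000000, 0x00000000, // keep first 1
            0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, 0x00000000, // keep first 2
            0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0x00000000, // keep first 3
        };

        internal static readonly uint[] TrailingAlignmentMask = new uint[16]
        {
            0x00000000, 0x00000000, 0x00000000, 0x00000000, // keep 0 lanes
            0x00000000, 0x00000000, 0x00000000, 0xFFFFFFFF, // keep last 1
            0x00000000, 0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, // keep last 2
            0x00000000, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, // keep last 3
        };
    }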

src/Microsoft.ML.CpuMath/Thunk.cs (+1 -1)

@@ -53,7 +53,7 @@ public static extern void MatMulP(/*const*/ float* pmat, /*const*/ int* pposSrc,
     public static extern void AddSU(/*const*/ float* ps, /*const*/ int* pi, float* pd, int c);

     [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
-    public static extern float SumU(/*const*/ float* ps, int c);
+    public static extern float Sum(/*const*/ float* pValues, int length);

     [DllImport(NativePath), SuppressUnmanagedCodeSecurity]
     public static extern float SumSqU(/*const*/ float* ps, int c);

src/Native/CpuMathNative/Sse.cpp (+91 -10)

@@ -903,21 +903,102 @@ EXPORT_API(void) MulElementWiseU(_In_ const float * ps1, _In_ const float * ps2,
     }
 }

-EXPORT_API(float) SumU(const float * ps, int c)
+EXPORT_API(float) Sum(const float* pValues, int length)
 {
-    const float * psLim = ps + c;
+    if (length < 4)
+    {
+        // Handle cases where we have less than 128-bits total and can't ever use SIMD acceleration.

-    __m128 res = _mm_setzero_ps();
-    for (; ps + 4 <= psLim; ps += 4)
-        res = _mm_add_ps(res, _mm_loadu_ps(ps));
+        float result = 0;

-    res = _mm_hadd_ps(res, res);
-    res = _mm_hadd_ps(res, res);
+        switch (length)
+        {
+            case 3: result += pValues[2];
+            case 2: result += pValues[1];
+            case 1: result += pValues[0];
+        }

-    for (; ps < psLim; ps++)
-        res = _mm_add_ss(res, _mm_load_ss(ps));
+        return result;
+    }

-    return _mm_cvtss_f32(res);
+    __m128 result = _mm_setzero_ps();
+
+    uintptr_t address = (uintptr_t)(pValues);
+    uintptr_t misalignment = address % 16;
+
+    int remainder = 0;
+
+    if ((misalignment & 3) != 0)
+    {
+        // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations
+
+        remainder = length % 4;
+
+        for (const float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 4)
+        {
+            __m128 temp = _mm_loadu_ps(pValues);
+            result = _mm_add_ps(result, temp);
+        }
+    }
+    else
+    {
+        if (misalignment != 0)
+        {
+            // Handle cases where the data is not 128-bit aligned by doing an unaligned read and then
+            // masking any elements that will be included in the first aligned read
+
+            misalignment >>= 2;
+            misalignment = 4 - misalignment;
+
+            __m128 temp = _mm_loadu_ps(pValues);
+            __m128 mask = _mm_loadu_ps(((float*)(&LeadingAlignmentMask)) + (misalignment * 4));
+            temp = _mm_and_ps(temp, mask);
+            result = _mm_add_ps(result, temp);
+
+            pValues += misalignment;
+            length -= misalignment;
+        }
+
+        if (length > 3)
+        {
+            // Handle all the 128-bit blocks that we can now that we have offset to an aligned address
+
+            remainder = length % 4;
+
+            for (const float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 4)
+            {
+                __m128 temp = _mm_load_ps(pValues);
+                result = _mm_add_ps(result, temp);
+            }
+        }
+        else
+        {
+            // Handle the "worst-case" scenario, which is when we have 4-8 elements and the input is not
+            // 128-bit aligned. This means we can't do any aligned loads and will just end up doing two
+            // unaligned loads where we mask the input each time.
+            remainder = length;
+        }
+    }
+
+    if (remainder != 0)
+    {
+        // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next
+        // unaligned load will read to the end of the array and then mask out any elements already processed
+
+        pValues -= (4 - remainder);
+
+        __m128 temp = _mm_loadu_ps(pValues);
+        __m128 mask = _mm_loadu_ps(((float*)(&TrailingAlignmentMask)) + (remainder * 4));
+        temp = _mm_and_ps(temp, mask);
+        result = _mm_add_ps(result, temp);
+    }
+
+    // Sum all the elements together and return the result
+
+    result = _mm_add_ps(result, _mm_movehl_ps(result, result));
+    result = _mm_add_ps(result, _mm_shuffle_ps(result, result, 0xB1));
+
+    return _mm_cvtss_f32(result);
 }

 EXPORT_API(float) SumSqU(const float * ps, int c)
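The final reduction replaces the two _mm_hadd_ps calls with a movehl/shuffle pair. A scalar model of those two steps (illustrative only, not from the commit), showing why the full total ends up in lane 0, which _mm_cvtss_f32 then extracts:

    // Scalar model of the two-step reduction at the end of Sum. `lanes` stands for the
    // four partial sums held in the __m128 register; only lane 0 of the final value matters.
    static float HorizontalSumModel(float[] lanes) // lanes.Length == 4
    {
        // _mm_movehl_ps(result, result) yields [r2, r3, r2, r3]; adding it gives
        // [r0+r2, r1+r3, r2+r2, r3+r3].
        float[] step1 =
        {
            lanes[0] + lanes[2],
            lanes[1] + lanes[3],
            lanes[2] + lanes[2],
            lanes[3] + lanes[3],
        };

        // _mm_shuffle_ps(x, x, 0xB1) swaps adjacent lane pairs: [x1, x0, x3, x2].
        // Adding that puts r0+r1+r2+r3 into lane 0.
        float[] step2 =
        {
            step1[0] + step1[1],
            step1[1] + step1[0],
            step1[2] + step1[3],
            step1[3] + step1[2],
        };

        return step2[0];
    }

For example, starting from partial sums {1, 2, 3, 4} the model returns 10, the sum of all four lanes.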
