@@ -1300,41 +1300,112 @@ public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan
             }
         }
 
-        public static unsafe float SumU(ReadOnlySpan<float> src)
+        public static unsafe float Sum(ReadOnlySpan<float> src)
         {
-            fixed (float* psrc = &MemoryMarshal.GetReference(src))
+            fixed (float* pSrc = &MemoryMarshal.GetReference(src))
+            fixed (uint* pLeadingAlignmentMask = &LeadingAlignmentMask[0])
+            fixed (uint* pTrailingAlignmentMask = &TrailingAlignmentMask[0])
             {
-                float* pSrcEnd = psrc + src.Length;
-                float* pSrcCurrent = psrc;
+                float* pValues = pSrc;
+                int length = src.Length;
 
-                Vector256<float> result256 = Avx.SetZeroVector256<float>();
-
-                while (pSrcCurrent + 8 <= pSrcEnd)
+                if (length < 8)
                 {
-                    result256 = Avx.Add(result256, Avx.LoadVector256(pSrcCurrent));
-                    pSrcCurrent += 8;
+                    // Handle cases where we have less than 256-bits total and can't ever use SIMD acceleration.
+
+                    float res = 0;
+
+                    switch (length)
+                    {
+                        case 7: res += pValues[6]; goto case 6;
+                        case 6: res += pValues[5]; goto case 5;
+                        case 5: res += pValues[4]; goto case 4;
+                        case 4: res += pValues[3]; goto case 3;
+                        case 3: res += pValues[2]; goto case 2;
+                        case 2: res += pValues[1]; goto case 1;
+                        case 1: res += pValues[0]; break;
+                    }
+
+                    return res;
                 }
 
-                result256 = VectorSum256(in result256);
-                Vector128<float> resultPadded = Sse.AddScalar(Avx.GetLowerHalf(result256), GetHigh(result256));
+                Vector256<float> result = Avx.SetZeroVector256<float>();
 
-                Vector128<float> result128 = Sse.SetZeroVector128();
+                nuint address = (nuint)(pValues);
+                int misalignment = (int)(address % 32);
+                int remainder = 0;
 
-                if (pSrcCurrent + 4 <= pSrcEnd)
+                if ((misalignment & 3) != 0)
                 {
-                    result128 = Sse.Add(result128, Sse.LoadVector128(pSrcCurrent));
-                    pSrcCurrent += 4;
+                    // Handles cases where the data is not 32-bit aligned and we can't ever use aligned operations
+
+                    remainder = length % 8;
+
+                    for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 8)
+                    {
+                        result = Avx.Add(result, Avx.LoadVector256(pValues));
+                    }
                 }
+                else
+                {
+                    if (misalignment != 0)
+                    {
+                        // Handle cases where the data is not 256-bit aligned by doing an unaligned read and then
+                        // masking any elements that will be included in the first aligned read
 
-                result128 = SseIntrinsics.VectorSum128(in result128);
+                        misalignment >>= 2;
+                        misalignment = 8 - misalignment;
 
-                while (pSrcCurrent < pSrcEnd)
+                        Vector256<float> mask = Avx.LoadVector256(((float*)(pLeadingAlignmentMask)) + (misalignment * 8));
+                        Vector256<float> temp = Avx.And(mask, Avx.LoadVector256(pValues));
+                        result = Avx.Add(result, temp);
+
+                        pValues += misalignment;
+                        length -= misalignment;
+                    }
+
+                    if (length > 7)
+                    {
+                        // Handle all the 256-bit blocks that we can now that we have offset to an aligned address
+
+                        remainder = length % 8;
+
+                        for (float* pEnd = pValues + (length - remainder); pValues < pEnd; pValues += 8)
+                        {
+                            // The JIT will only fold away unaligned loads due to the semantics behind
+                            // the VEX-encoding of the memory operand for `ins xmm, xmm, [mem]`. Since
+                            // modern hardware has unaligned loads that are as fast as aligned loads,
+                            // when it doesn't cross a cache-line/page boundary, we will just assert
+                            // that the alignment is correct and allow for the more-efficient codegen.
+
+                            Contracts.Assert(((nuint)(pValues) % 32) == 0);
+                            result = Avx.Add(result, Avx.LoadVector256(pValues));
+                        }
+                    }
+                    else
+                    {
+                        // Handle the "worst-case" scenario, which is when we have 8-16 elements and the input is not
+                        // 256-bit aligned. This means we can't do any aligned loads and will just end up doing two
+                        // unaligned loads where we mask the input each time.
+                        remainder = length;
+                    }
+                }
+
+                if (remainder != 0)
                 {
-                    result128 = Sse.AddScalar(result128, Sse.LoadScalarVector128(pSrcCurrent));
-                    pSrcCurrent++;
+                    // Handle any trailing elements that don't fit into a 128-bit block by moving back so that the next
+                    // unaligned load will read to the end of the array and then mask out any elements already processed
+
+                    pValues -= (8 - remainder);
+
+                    Vector256<float> mask = Avx.LoadVector256(((float*)(pTrailingAlignmentMask)) + (remainder * 8));
+                    Vector256<float> temp = Avx.And(mask, Avx.LoadVector256(pValues));
+                    result = Avx.Add(result, temp);
                 }
 
-                return Sse.ConvertToSingle(Sse.AddScalar(result128, resultPadded));
+                // Sum all the elements together and return the result
+                result = VectorSum256(in result);
+                return Sse.ConvertToSingle(Sse.AddScalar(Avx.GetLowerHalf(result), GetHigh(result)));
             }
         }
 
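The new path leans on two 8x8 mask tables, `LeadingAlignmentMask` and `TrailingAlignmentMask`, whose definitions sit outside this hunk. The idea: if `pValues` is, say, 8 bytes past a 32-byte boundary, `misalignment` starts at 8, `>>= 2` turns that into 2 elements, and `8 - 2 = 6` elements remain before the next boundary, so row 6 of the leading table keeps those 6 lanes of the first unaligned load and zeroes the 2 lanes the following aligned loads will cover; the trailing table mirrors this, keeping only the last `remainder` lanes of the final backed-up load. The sketch below is a hypothetical, managed-only illustration of that layout and masking; the table names and row convention are assumptions for illustration, not the literal arrays from AvxIntrinsics.cs.

```csharp
using System;

// Hypothetical sketch of the 8x8 lane-mask tables the diff loads from
// (LeadingAlignmentMask / TrailingAlignmentMask). The real arrays live elsewhere
// in AvxIntrinsics.cs; this managed-only version only illustrates the layout
// and the effect of the bitwise AND, without any AVX intrinsics.
static class AlignmentMaskSketch
{
    // Row i keeps the first i lanes of an 8-float (256-bit) block.
    static readonly uint[] Leading = BuildTable(keepLane: (row, lane) => lane < row);

    // Row i keeps the last i lanes of an 8-float (256-bit) block.
    static readonly uint[] Trailing = BuildTable(keepLane: (row, lane) => lane >= 8 - row);

    static uint[] BuildTable(Func<int, int, bool> keepLane)
    {
        var table = new uint[8 * 8];
        for (int row = 0; row < 8; row++)
            for (int lane = 0; lane < 8; lane++)
                table[row * 8 + lane] = keepLane(row, lane) ? 0xFFFFFFFFu : 0u;
        return table;
    }

    // Scalar stand-in for Avx.And(mask, Avx.LoadVector256(p)): zero every lane
    // whose mask bits are clear.
    static float[] ApplyMask(float[] block, uint[] table, int row)
    {
        var masked = new float[8];
        for (int lane = 0; lane < 8; lane++)
            masked[lane] = table[row * 8 + lane] != 0 ? block[lane] : 0f;
        return masked;
    }

    static void Main()
    {
        float[] block = { 1, 2, 3, 4, 5, 6, 7, 8 };

        // First unaligned read with 6 elements left before the next 32-byte
        // boundary (misalignment == 6 in the diff's terms): keep lanes 0-5.
        Console.WriteLine(string.Join(", ", ApplyMask(block, Leading, 6)));
        // -> 1, 2, 3, 4, 5, 6, 0, 0

        // Final backed-up read with remainder == 3: keep only the last 3 lanes,
        // the ones not already summed by the main loop.
        Console.WriteLine(string.Join(", ", ApplyMask(block, Trailing, 3)));
        // -> 0, 0, 0, 0, 0, 6, 7, 8
    }
}
```

Masking this way keeps every read a full 256-bit load while guaranteeing no element is summed twice, which is why the head and tail are handled with AND-masked unaligned loads instead of scalar loops.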