Skip to content

Commit a222fa9

Browse files
committed
add int8/tf32 transpose A intrinsic
1 parent f6ca3e7 commit a222fa9

File tree

3 files changed

+158
-7
lines changed

3 files changed

+158
-7
lines changed

include/cute/arch/xe_copy_1B.hpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,15 @@ SYCL_DEVICE_BUILTIN(
106106
intptr_t baseoffset, int width_minus_one, int height_minus_one,
107107
int pitch_minus_one, cute::intel::coord_t coord));
108108

109+
// 8bits NO transform transpose
110+
SYCL_DEVICE_BUILTIN(
111+
cute::intel::ushort8 __builtin_IB_subgroup_block_read_cacheopts_transpose_u8_m32k8(
112+
intptr_t baseoffset, int width_minus_one, int height_minus_one,
113+
int pitch_minus_one, cute::intel::coord_t coord, int cache = 0));
114+
SYCL_DEVICE_BUILTIN(
115+
cute::intel::ushort4 __builtin_IB_subgroup_block_read_cacheopts_transpose_u8_m32k4(
116+
intptr_t baseoffset, int width_minus_one, int height_minus_one,
117+
int pitch_minus_one, cute::intel::coord_t coord, int cache = 0));
109118

110119
// 8bits VNNI transform No transpose
111120
SYCL_DEVICE_BUILTIN(
@@ -443,6 +452,45 @@ struct XE_2D_U8x32x32_LD_N {
443452
}
444453
};
445454

455+
struct XE_2D_U8x32x4_LD_T {
456+
using BlockShape = Shape<_4, _32>;
457+
using inst_dtype = uint8_t;
458+
static constexpr bool is_transpose = true;
459+
460+
template <class T>
461+
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
462+
int height, int pitch, intel::coord_t coord,
463+
T *dst) {
464+
#if defined(SYCL_INTEL_TARGET)
465+
static_assert(sizeof(T) == 1, "Expected T to have size 1");
466+
*reinterpret_cast<intel::ushort4 *>(dst) =
467+
__builtin_IB_subgroup_block_read_cacheopts_transpose_u8_m32k4(
468+
(intptr_t)(baseoffset), width - 1, height - 1, pitch - 1, coord);
469+
#else
470+
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
471+
#endif
472+
}
473+
};
474+
475+
struct XE_2D_U8x32x8_LD_T {
476+
using BlockShape = Shape<_8, _32>;
477+
using inst_dtype = uint8_t;
478+
static constexpr bool is_transpose = true;
479+
480+
template <class T>
481+
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
482+
int height, int pitch, intel::coord_t coord,
483+
T *dst) {
484+
#if defined(SYCL_INTEL_TARGET)
485+
static_assert(sizeof(T) == 1, "Expected T to have size 1");
486+
*reinterpret_cast<intel::ushort8 *>(dst) =
487+
__builtin_IB_subgroup_block_read_cacheopts_transpose_u8_m32k8(
488+
(intptr_t)(baseoffset), width - 1, height - 1, pitch - 1, coord);
489+
#else
490+
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
491+
#endif
492+
}
493+
};
446494
struct XE_2D_U4x16x16_LD_T {
447495
using BlockShape = Shape<_16, _16>;
448496
using inst_dtype = uint32_t;

include/cute/arch/xe_copy_4B.hpp

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,10 @@ SYCL_DEVICE_BUILTIN(
117117
int pitch_minus_one, cute::intel::coord_t coord));
118118

119119
// 32bits No transform No transpose
120-
SYCL_DEVICE_BUILTIN(cute::intel::uint __builtin_IB_subgroup_block_read_flat_u32_m1k16v1(
121-
intptr_t baseoffset, int width_minus_one, int height_minus_one,
122-
int pitch_minus_one, cute::intel::coord_t coord));
120+
SYCL_DEVICE_BUILTIN(
121+
cute::intel::uint __builtin_IB_subgroup_block_read_flat_u32_m1k16v1(
122+
intptr_t baseoffset, int width_minus_one, int height_minus_one,
123+
int pitch_minus_one, cute::intel::coord_t coord));
123124
SYCL_DEVICE_BUILTIN(
124125
cute::intel::uint2 __builtin_IB_subgroup_block_read_flat_u32_m2k16v1(
125126
intptr_t baseoffset, int width_minus_one, int height_minus_one,
@@ -142,9 +143,10 @@ SYCL_DEVICE_BUILTIN(
142143
int pitch_minus_one, cute::intel::coord_t coord));
143144

144145
// 32bits No transform Transpose
145-
SYCL_DEVICE_BUILTIN(cute::intel::uint __builtin_IB_subgroup_block_read_flat_transpose_u32_k1(
146-
intptr_t baseoffset, int width_minus_one, int height_minus_one,
147-
int pitch_minus_one, cute::intel::coord_t coord));
146+
SYCL_DEVICE_BUILTIN(
147+
cute::intel::uint __builtin_IB_subgroup_block_read_flat_transpose_u32_k1(
148+
intptr_t baseoffset, int width_minus_one, int height_minus_one,
149+
int pitch_minus_one, cute::intel::coord_t coord));
148150
SYCL_DEVICE_BUILTIN(
149151
cute::intel::uint2 __builtin_IB_subgroup_block_read_flat_transpose_u32_k2(
150152
intptr_t baseoffset, int width_minus_one, int height_minus_one,
@@ -157,6 +159,10 @@ SYCL_DEVICE_BUILTIN(
157159
cute::intel::uint8 __builtin_IB_subgroup_block_read_flat_transpose_u32_k8(
158160
intptr_t baseoffset, int width_minus_one, int height_minus_one,
159161
int pitch_minus_one, cute::intel::coord_t coord));
162+
SYCL_DEVICE_BUILTIN(
163+
cute::intel::uint4 __builtin_IB_subgroup_block_read_flat_transpose_u32_m8k8(
164+
intptr_t baseoffset, int width_minus_one, int height_minus_one,
165+
int pitch_minus_one, cute::intel::coord_t coord));
160166

161167
// 32bits
162168
SYCL_DEVICE_BUILTIN(void __builtin_IB_subgroup_block_write_flat_u32_m1k16v1(
@@ -710,6 +716,27 @@ struct XE_2D_U32x16x8_LD_T {
710716
};
711717
};
712718

719+
struct XE_2D_TF32x8x8_LD_T {
720+
using BlockShape = Shape<_8, _8>;
721+
using ValueShape = Shape<_4, _16>;
722+
723+
static constexpr bool is_transpose = true;
724+
725+
template <class T>
726+
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
727+
int height, int pitch, intel::coord_t coord,
728+
T *dst) {
729+
#if defined(SYCL_INTEL_TARGET)
730+
static_assert(sizeof(T) == 4, "Expected T to have size 4");
731+
*reinterpret_cast<intel::uint4 *>(dst) =
732+
__builtin_IB_subgroup_block_read_flat_transpose_u32_m8k8(
733+
(intptr_t)(baseoffset), width - 1, height - 1, pitch - 1, coord);
734+
#else
735+
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-PVC hardware");
736+
#endif
737+
}
738+
};
739+
713740
struct XE_2D_U32x1x16_ST_N {
714741
using BlockShape = Shape<_1, _16>;
715742

include/cute/atom/copy_traits_xe.hpp

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1404,7 +1404,25 @@ struct Copy_Traits_<XE_2D_TF32x32x16_LD_N, args_t...>
14041404
};
14051405

14061406
template <class... args_t>
1407-
struct Copy_Traits_<XE_2D_U32x1x16_LD_N, args_t...>
1407+
struct Copy_Traits_<XE_2D_TF32x8x8_LD_T, args_t...>
1408+
: XE_2D_LD_Unpack<XE_2D_TF32x8x8_LD_T, args_t...> {
1409+
using ThrID = Layout<_16>;
1410+
// Map from (src-thr,src-val) to bit
1411+
using SrcLayout = Layout<Shape <_16, Shape <_4, _32>>,
1412+
Stride< _0, Stride<_32, _1>>>;
1413+
// Map from (dst-thr,dst-val) to bit
1414+
using DstLayout = Layout<Shape <_16, Shape <_4, _32>>,
1415+
Stride< _32, Stride<_32, _1>>>;
1416+
// Reference map from (thr,val) to bit
1417+
using RefLayout = DstLayout;
1418+
1419+
template <class... ArgTs>
1420+
Copy_Traits_(ArgTs... args)
1421+
: XE_2D_LD_Unpack<XE_2D_TF32x8x8_LD_T, args_t...>(args...) {}
1422+
};
1423+
1424+
template <class... args_t>
1425+
struct Copy_Traits<XE_2D_U32x1x16_LD_N, args_t...>
14081426
: XE_2D_LD_Unpack<XE_2D_U32x1x16_LD_N, args_t...> {
14091427
using ThrID = Layout<_16>;
14101428
// Map from (src-thr,src-val) to bit
@@ -1693,6 +1711,61 @@ struct Copy_Traits_<XE_2D_U16x16x16_LD_T, args_t...>
16931711
: XE_2D_LD_Unpack<XE_2D_U16x16x16_LD_T, args_t...>(args...) {}
16941712
};
16951713

1714+
template <class... args_t>
1715+
struct Copy_Traits_<XE_2D_U8x32x16_LD_T, args_t...>
1716+
: XE_2D_LD_Unpack<XE_2D_U8x32x16_LD_T, args_t...> {
1717+
using ThrID = Layout<_16>;
1718+
// Map from (src-thr,src-val) to bit
1719+
using SrcLayout = Layout<Shape <_16,_16>,
1720+
Stride< _0, _1>>;
1721+
// Map from (dst-thr,dst-val) to bit
1722+
using DstLayout = Layout<Shape < _16,Shape <_16,_16>>,
1723+
Stride<_256,Stride< _1,_16>>>;
1724+
// Reference map from (thr,val) to bit
1725+
using RefLayout = DstLayout;
1726+
1727+
template <class... ArgT>
1728+
Copy_Traits_(ArgT... args)
1729+
: XE_2D_LD_Unpack<XE_2D_U8x32x16_LD_T, args_t...>(args...) {}
1730+
};
1731+
1732+
template <class... args_t>
1733+
struct Copy_Traits_<XE_2D_U8x32x8_LD_T, args_t...>
1734+
: XE_2D_LD_Unpack<XE_2D_U8x32x8_LD_T, args_t...> {
1735+
using ThrID = Layout<_16>;
1736+
// Map from (src-thr,src-val) to bit
1737+
using SrcLayout = Layout<Shape <_16,Shape <_8, _2, _8>>,
1738+
Stride<_0, Stride<_1, _8, _16>>>;
1739+
// Map from (dst-thr,dst-val) to bit
1740+
using DstLayout = Layout<Shape < _16,Shape <_8, _2, _8>>,
1741+
Stride<_256,Stride<_1, _8, _16>>>;
1742+
// Reference map from (thr,val) to bit
1743+
using RefLayout = DstLayout;
1744+
1745+
template <class... ArgT>
1746+
Copy_Traits_(ArgT... args)
1747+
: XE_2D_LD_Unpack<XE_2D_U8x32x8_LD_T, args_t...>(args...) {}
1748+
};
1749+
1750+
template <class... args_t>
1751+
struct Copy_Traits<XE_2D_U8x32x4_LD_T, args_t...>
1752+
: XE_2D_LD_Unpack<XE_2D_U8x32x4_LD_T, args_t...> {
1753+
using ThrID = Layout<_16>;
1754+
// Map from (src-thr,src-val) to bit
1755+
using SrcLayout = Layout<Shape <_16,Shape <_8, _2, _4>>,
1756+
Stride<_0, Stride<_1, _8, _16>>>;
1757+
// Map from (dst-thr,dst-val) to bit
1758+
using DstLayout = Layout<Shape < _16,Shape <_8, _2, _4>>,
1759+
Stride<_256,Stride<_1, _8, _16>>>;
1760+
// Reference map from (thr,val) to bit
1761+
using RefLayout = DstLayout;
1762+
1763+
template <class... ArgT>
1764+
Copy_Traits(ArgT... args)
1765+
: XE_2D_LD_Unpack<XE_2D_U8x32x4_LD_T, args_t...>(args...) {}
1766+
};
1767+
};
1768+
16961769
// template<class... args_t>
16971770
// struct Copy_Traits<XE_2D_U32x16x1_LD_T, args_t...>
16981771
// : XE_2D_LD_Unpack<XE_2D_U32x16x1_LD_T, args_t...> {
@@ -2237,6 +2310,8 @@ COPY_TRAIT_LD_DEF(XE_2D_U8x16x32_LD_N)
22372310
COPY_TRAIT_LD_DEF(XE_2D_U8x32x32_LD_N)
22382311
COPY_TRAIT_LD_DEF(XE_2D_U8x16x64_LD_N)
22392312
COPY_TRAIT_LD_DEF(XE_2D_U8x32x64_LD_N)
2313+
COPY_TRAIT_LD_DEF(XE_2D_U8x32x8_LD_T)
2314+
COPY_TRAIT_LD_DEF(XE_2D_U8x32x4_LD_T)
22402315
COPY_TRAIT_LD_DEF(XE_2D_U16x1x16_LD_N)
22412316
COPY_TRAIT_LD_DEF(XE_2D_U16x2x16_LD_N)
22422317
COPY_TRAIT_LD_DEF(XE_2D_U16x4x16_LD_N)
@@ -2279,6 +2354,7 @@ COPY_TRAIT_LD_DEF(XE_2D_U16x16x32_LD_V)
22792354
COPY_TRAIT_LD_DEF(XE_2D_U16x16x16_LD_T)
22802355
COPY_TRAIT_LD_DEF(XE_2D_TF32x16x16_LD_N)
22812356
COPY_TRAIT_LD_DEF(XE_2D_TF32x32x16_LD_N)
2357+
COPY_TRAIT_LD_DEF(XE_2D_TF32x8x8_LD_T)
22822358
COPY_TRAIT_LD_DEF(XE_2D_U4x32x64_LD_N)
22832359
COPY_TRAIT_LD_DEF(XE_2D_U4x16x64_LD_N)
22842360
COPY_TRAIT_LD_DEF(XE_2D_U4x32x16_LD_T)

0 commit comments

Comments
 (0)