7
7
#include <cstdint>
#include <cstring>
#include <iostream>

#include <sycl/sycl.hpp>
9
9
10
- // using joint_matrix = sycl::ext::oneapi::experimental::matrix;
11
10
using use = sycl::ext::oneapi::experimental::matrix::use;
12
11
using layout = sycl::ext::oneapi::experimental::matrix::layout;
13
12
using bfloat16 = sycl::ext::oneapi::bfloat16;
14
13
15
- # define SG_SZ 16
14
+ constexpr size_t SG_SZ = 16 ;
16
15
17
- # define TM 8
18
- # define TN SG_SZ
19
- # define TK 16
16
+ constexpr size_t TM = 8 ;
17
+ constexpr size_t TN = SG_SZ;
18
+ constexpr size_t TK = 16 ;
20
19
21
- #define BF16_EPSILON 0.00781250
20
+ constexpr float ALPHA = 2.0 ;
21
+
22
+ constexpr float BF16_EPSILON = 0.00781250 ;
22
23
23
24
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
24
25
private:
@@ -42,10 +43,9 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
42
43
43
44
sycl::queue q;
44
45
q.submit ([&](sycl::handler &cgh) {
45
- sycl::accessor accC (bufC, cgh, sycl::read_write, sycl::no_init );
46
+ sycl::accessor accC (bufC, cgh, sycl::read_write);
46
47
sycl::accessor accA (bufA, cgh, sycl::read_only);
47
48
sycl::accessor accB (bufB, cgh, sycl::read_only);
48
-
49
49
cgh.parallel_for (
50
50
sycl::nd_range<2 >({NDRangeM, NDRangeN * SG_SZ}, {1 , 1 * SG_SZ}),
51
51
[=](sycl::nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]]
@@ -66,30 +66,32 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
66
66
// For B, we assume B has been already VNNIed.
67
67
sycl::ext::oneapi::experimental::matrix::joint_matrix<
68
68
sycl::sub_group, bfloat16, use::b, TK, TN,
69
- sycl::ext::intel::experimental::matrix:: layout::packed >
69
+ layout::ext_intel_packed >
70
70
sub_b;
71
71
sycl::ext::oneapi::experimental::matrix::joint_matrix<
72
72
sycl::sub_group, float , use::accumulator, TM, TN>
73
73
sub_c;
74
74
75
- joint_matrix_load (sg, sub_c,
76
- accC.get_pointer () + (sg_startx * TM) * N +
77
- sg_starty / SG_SZ * TN,
78
- N, layout::row_major);
79
- for (int k = 0 ; k < K / TK; k += 1 ) { //
75
+ joint_matrix_fill (sg, sub_c, 1.0 );
76
+ for (int k = 0 ; k < K / TK; k += 1 ) {
80
77
joint_matrix_load (
81
- sg, sub_a, accA.get_pointer () + (sg_startx * TM) * K + k * TK,
78
+ sg, sub_a,
79
+ accA.template get_multi_ptr <sycl::access ::decorated::no>() +
80
+ (sg_startx * TM) * K + k * TK,
82
81
K);
83
- joint_matrix_load (sg, sub_b,
84
- accB.get_pointer () + (k * TK / 2 ) * (N * 2 ) +
85
- sg_starty / SG_SZ * TN * 2 ,
86
- N * 2 );
87
- sub_c = joint_matrix_mad (sg, sub_a, sub_b, sub_c);
82
+ joint_matrix_load (
83
+ sg, sub_b,
84
+ accB.template get_multi_ptr <sycl::access ::decorated::no>() +
85
+ (k * TK / 2 ) * (N * 2 ) + sg_starty / SG_SZ * TN * 2 ,
86
+ N * 2 );
87
+ joint_matrix_mad (sg, sub_c, sub_a, sub_b, sub_c);
88
88
}
89
- joint_matrix_store (sg, sub_c,
90
- accC.get_pointer () + (sg_startx * TM) * N +
91
- sg_starty / SG_SZ * TN,
92
- N, layout::row_major);
89
+ joint_matrix_apply (sg, sub_c, [=](float &x) { x *= ALPHA; });
90
+ joint_matrix_store (
91
+ sg, sub_c,
92
+ accC.template get_multi_ptr <sycl::access ::decorated::no>() +
93
+ (sg_startx * TM) * N + sg_starty / SG_SZ * TN,
94
+ N, layout::row_major);
93
95
}); // parallel for
94
96
}).wait ();
95
97
// kernel end
@@ -100,53 +102,43 @@ static constexpr size_t MATRIX_N = TN * 2;
100
102
static constexpr size_t MATRIX_K = TK * 2 ;
101
103
bfloat16 A[MATRIX_M][MATRIX_K];
102
104
bfloat16 B[MATRIX_K / 2 ][MATRIX_N * 2 ];
103
- unsigned short Aref[MATRIX_M][MATRIX_K];
104
- unsigned short Bref[MATRIX_K / 2 ][MATRIX_N * 2 ];
105
105
float C[MATRIX_M][MATRIX_N];
106
106
float D[MATRIX_M][MATRIX_N];
107
107
108
- float make_fp32 (short x) {
109
- unsigned int y = x ;
108
+ float make_fp32 (bfloat16 x) {
109
+ unsigned int y = *(( int *)&x) ;
110
110
y = y << 16 ;
111
111
float *res = reinterpret_cast <float *>(&y);
112
112
return *res;
113
113
}
114
114
115
- unsigned short make_bf16 (float x) {
116
- int *res = reinterpret_cast <int *>(&x);
117
- *res = *res >> 16 ;
118
- return (unsigned short )*res;
119
- }
120
-
121
115
void matrix_multiply_ref (int *A_mem, int *B_mem, int *C_mem, int M, int N,
122
116
int K) {
123
117
for (int m = 0 ; m < M; m++)
124
118
for (int n = 0 ; n < N; n++) {
125
119
for (int k = 0 ; k < K; k++) {
126
- short *va = (short *)(A_mem + m * K + k);
127
- short *vb = (short *)(B_mem + k * N + n);
120
+ // Because B was assumed VNNIed
121
+ bfloat16 *va = (bfloat16 *)(A_mem + m * K + k);
122
+ bfloat16 *vb = (bfloat16 *)(B_mem + k * N + n);
128
123
float acc = *((float *)(C_mem + m * N + n));
129
124
for (int i = 0 ; i < 2 ; i++) {
130
125
acc += (make_fp32 (va[i]) * make_fp32 (vb[i]));
131
126
}
132
127
*((float *)(C_mem + m * N + n)) = acc;
133
128
}
129
+ *((float *)(C_mem + m * N + n)) *= ALPHA;
134
130
}
135
131
}
136
132
137
133
int main () {
138
134
for (int i = 0 ; i < MATRIX_M; i++) {
139
135
for (int j = 0 ; j < MATRIX_K; j++) {
140
- // bfloat16 is created using unsigned short since conversion from float to
141
- // bfloat16 is not supported on the host side yet
142
136
A[i][j] = bfloat16 (1 .0f * (i + j));
143
- Aref[i][j] = make_bf16 (1 .0f * (i + j));
144
137
}
145
138
}
146
139
for (int i = 0 ; i < MATRIX_K / 2 ; i++) {
147
140
for (int j = 0 ; j < MATRIX_N * 2 ; j++) {
148
141
B[i][j] = bfloat16 (2 .0f * i + 3 .0f * j);
149
- Bref[i][j] = make_bf16 (2 .0f * i + 3 .0f * j);
150
142
}
151
143
}
152
144
for (int i = 0 ; i < MATRIX_M; i++) {
@@ -161,13 +153,13 @@ int main() {
161
153
big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA ((bfloat16 *)&A);
162
154
big_matrix<bfloat16, MATRIX_K / 2 , MATRIX_N * 2 > MB ((bfloat16 *)&B);
163
155
matrix_multiply (MC, MA, MB);
164
- matrix_multiply_ref ((int32_t *)Aref , (int32_t *)Bref , (int32_t *)D, MATRIX_M,
156
+ matrix_multiply_ref ((int32_t *)A , (int32_t *)B , (int32_t *)D, MATRIX_M,
165
157
MATRIX_N, MATRIX_K / 2 );
166
158
167
159
bool res = true ;
168
160
for (int i = 0 ; i < MATRIX_M; i++) {
169
161
for (int j = 0 ; j < MATRIX_N; j++) {
170
- if ((fabs (C[i][j]) - fabs ( D[i][j])) > BF16_EPSILON)
162
+ if ((fabs (C[i][j] - D[i][j])) > BF16_EPSILON)
171
163
res = false ;
172
164
}
173
165
}
0 commit comments