Skip to content

Commit d217052

Browse files
authored
[GPU opt guide][SYCL][joint matrix] update the test to match the guide (#2147)
1 parent 4bb588d commit d217052

File tree

1 file changed

+34
-42
lines changed

1 file changed

+34
-42
lines changed

Publications/GPU-Opt-Guide/joint-matrix/joint-matrix.cpp

+34-42
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,19 @@
77
#include <iostream>
88
#include <sycl/sycl.hpp>
99

10-
// using joint_matrix = sycl::ext::oneapi::experimental::matrix;
1110
using use = sycl::ext::oneapi::experimental::matrix::use;
1211
using layout = sycl::ext::oneapi::experimental::matrix::layout;
1312
using bfloat16 = sycl::ext::oneapi::bfloat16;
1413

15-
#define SG_SZ 16
14+
constexpr size_t SG_SZ = 16;
1615

17-
#define TM 8
18-
#define TN SG_SZ
19-
#define TK 16
16+
constexpr size_t TM = 8;
17+
constexpr size_t TN = SG_SZ;
18+
constexpr size_t TK = 16;
2019

21-
#define BF16_EPSILON 0.00781250
20+
constexpr float ALPHA = 2.0;
21+
22+
constexpr float BF16_EPSILON = 0.00781250;
2223

2324
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
2425
private:
@@ -42,10 +43,9 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
4243

4344
sycl::queue q;
4445
q.submit([&](sycl::handler &cgh) {
45-
sycl::accessor accC(bufC, cgh, sycl::read_write, sycl::no_init);
46+
sycl::accessor accC(bufC, cgh, sycl::read_write);
4647
sycl::accessor accA(bufA, cgh, sycl::read_only);
4748
sycl::accessor accB(bufB, cgh, sycl::read_only);
48-
4949
cgh.parallel_for(
5050
sycl::nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
5151
[=](sycl::nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]
@@ -66,30 +66,32 @@ void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
6666
// For B, we assume B has been already VNNIed.
6767
sycl::ext::oneapi::experimental::matrix::joint_matrix<
6868
sycl::sub_group, bfloat16, use::b, TK, TN,
69-
sycl::ext::intel::experimental::matrix::layout::packed>
69+
layout::ext_intel_packed>
7070
sub_b;
7171
sycl::ext::oneapi::experimental::matrix::joint_matrix<
7272
sycl::sub_group, float, use::accumulator, TM, TN>
7373
sub_c;
7474

75-
joint_matrix_load(sg, sub_c,
76-
accC.get_pointer() + (sg_startx * TM) * N +
77-
sg_starty / SG_SZ * TN,
78-
N, layout::row_major);
79-
for (int k = 0; k < K / TK; k += 1) { //
75+
joint_matrix_fill(sg, sub_c, 1.0);
76+
for (int k = 0; k < K / TK; k += 1) {
8077
joint_matrix_load(
81-
sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
78+
sg, sub_a,
79+
accA.template get_multi_ptr<sycl::access::decorated::no>() +
80+
(sg_startx * TM) * K + k * TK,
8281
K);
83-
joint_matrix_load(sg, sub_b,
84-
accB.get_pointer() + (k * TK / 2) * (N * 2) +
85-
sg_starty / SG_SZ * TN * 2,
86-
N * 2);
87-
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
82+
joint_matrix_load(
83+
sg, sub_b,
84+
accB.template get_multi_ptr<sycl::access::decorated::no>() +
85+
(k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2,
86+
N * 2);
87+
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
8888
}
89-
joint_matrix_store(sg, sub_c,
90-
accC.get_pointer() + (sg_startx * TM) * N +
91-
sg_starty / SG_SZ * TN,
92-
N, layout::row_major);
89+
joint_matrix_apply(sg, sub_c, [=](float &x) { x *= ALPHA; });
90+
joint_matrix_store(
91+
sg, sub_c,
92+
accC.template get_multi_ptr<sycl::access::decorated::no>() +
93+
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
94+
N, layout::row_major);
9395
}); // parallel for
9496
}).wait();
9597
// kernel end
@@ -100,53 +102,43 @@ static constexpr size_t MATRIX_N = TN * 2;
100102
static constexpr size_t MATRIX_K = TK * 2;
101103
bfloat16 A[MATRIX_M][MATRIX_K];
102104
bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
103-
unsigned short Aref[MATRIX_M][MATRIX_K];
104-
unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2];
105105
float C[MATRIX_M][MATRIX_N];
106106
float D[MATRIX_M][MATRIX_N];
107107

108-
float make_fp32(short x) {
109-
unsigned int y = x;
108+
float make_fp32(bfloat16 x) {
109+
unsigned int y = *((int *)&x);
110110
y = y << 16;
111111
float *res = reinterpret_cast<float *>(&y);
112112
return *res;
113113
}
114114

115-
unsigned short make_bf16(float x) {
116-
int *res = reinterpret_cast<int *>(&x);
117-
*res = *res >> 16;
118-
return (unsigned short)*res;
119-
}
120-
121115
void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
122116
int K) {
123117
for (int m = 0; m < M; m++)
124118
for (int n = 0; n < N; n++) {
125119
for (int k = 0; k < K; k++) {
126-
short *va = (short *)(A_mem + m * K + k);
127-
short *vb = (short *)(B_mem + k * N + n);
120+
// Because B was assumed VNNIed
121+
bfloat16 *va = (bfloat16 *)(A_mem + m * K + k);
122+
bfloat16 *vb = (bfloat16 *)(B_mem + k * N + n);
128123
float acc = *((float *)(C_mem + m * N + n));
129124
for (int i = 0; i < 2; i++) {
130125
acc += (make_fp32(va[i]) * make_fp32(vb[i]));
131126
}
132127
*((float *)(C_mem + m * N + n)) = acc;
133128
}
129+
*((float *)(C_mem + m * N + n)) *= ALPHA;
134130
}
135131
}
136132

137133
int main() {
138134
for (int i = 0; i < MATRIX_M; i++) {
139135
for (int j = 0; j < MATRIX_K; j++) {
140-
// bfloat16 is created using unsigned short since conversion from float to
141-
// bfloat16 is not supported on the host side yet
142136
A[i][j] = bfloat16(1.0f * (i + j));
143-
Aref[i][j] = make_bf16(1.0f * (i + j));
144137
}
145138
}
146139
for (int i = 0; i < MATRIX_K / 2; i++) {
147140
for (int j = 0; j < MATRIX_N * 2; j++) {
148141
B[i][j] = bfloat16(2.0f * i + 3.0f * j);
149-
Bref[i][j] = make_bf16(2.0f * i + 3.0f * j);
150142
}
151143
}
152144
for (int i = 0; i < MATRIX_M; i++) {
@@ -161,13 +153,13 @@ int main() {
161153
big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
162154
big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);
163155
matrix_multiply(MC, MA, MB);
164-
matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M,
156+
matrix_multiply_ref((int32_t *)A, (int32_t *)B, (int32_t *)D, MATRIX_M,
165157
MATRIX_N, MATRIX_K / 2);
166158

167159
bool res = true;
168160
for (int i = 0; i < MATRIX_M; i++) {
169161
for (int j = 0; j < MATRIX_N; j++) {
170-
if ((fabs(C[i][j]) - fabs(D[i][j])) > BF16_EPSILON)
162+
if ((fabs(C[i][j] - D[i][j])) > BF16_EPSILON)
171163
res = false;
172164
}
173165
}

0 commit comments

Comments
 (0)