
Commit ae4cca3

CPU/CUDA: fix (GQA) mul mat back, add CUDA support
1 parent 564804b commit ae4cca3

7 files changed: +157 -62 lines changed

ggml/src/ggml-cpu/ggml-cpu.c

+2 -2

@@ -7883,7 +7883,7 @@ static void ggml_compute_forward_out_prod_f32(
 
                 float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
                 float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+                float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
 
                 ggml_vec_mad_f32_unroll(ne0, nb01, nb11, d, s0, s1);
             }
@@ -7892,7 +7892,7 @@ static void ggml_compute_forward_out_prod_f32(
 
                 float * s0 = (float *) ((char *) src0->data + ( i01*nb01 + i02*nb02 + i03*nb03));
                 float * s1 = (float *) ((char *) src1->data + (i1*nb10 + i11*nb11 + i12*nb12 + i13*nb13));
-                float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
+                float * d = (float *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
 
                 ggml_vec_mad_f32(ne0, d, s0, *s1);
             }
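For reference, the accumulation helpers called in these hunks reduce to a scalar multiply-add per output row. A minimal plain-C sketch of the semantics of ggml_vec_mad_f32 (not the vectorized implementation in ggml-cpu.c):

    // computes y[i] += x[i]*v; in this file: d[i] += s0[i] * (*s1)
    static void vec_mad_ref(const int n, float * y, const float * x, const float v) {
        for (int i = 0; i < n; ++i) {
            y[i] += x[i]*v;
        }
    }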

ggml/src/ggml-cpu/ggml-cpu.cpp

+2 -1

@@ -416,7 +416,8 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
         case GGML_OP_IM2COL_BACK:
             return src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32;
         case GGML_OP_OUT_PROD:
-            return (src0->type == GGML_TYPE_F32 || ggml_is_quantized(src0->type)) && src1->type == GGML_TYPE_F32;
+            return (src0->type == GGML_TYPE_F32 || (ggml_is_quantized(src0->type) && src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3])) &&
+                src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
         default:
             return true;
     }
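The tightened GGML_OP_OUT_PROD case now declines quantized src0 tensors whose batch dims would need broadcasting, and additionally requires an F32 destination, so such cases are reported as unsupported rather than computed incorrectly. A condensed sketch of the new predicate, using the same fields as the diff above:

    static bool cpu_supports_out_prod(const struct ggml_tensor * op) {
        const struct ggml_tensor * src0 = op->src[0];
        const struct ggml_tensor * src1 = op->src[1];
        const bool src0_ok = src0->type == GGML_TYPE_F32 ||
            (ggml_is_quantized(src0->type) &&
             src0->ne[2] == src1->ne[2] && src0->ne[3] == src1->ne[3]); // no batch broadcast for quantized src0
        return src0_ok && src1->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
    }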

ggml/src/ggml-cuda/binbcast.cu

+30 -24

@@ -93,26 +93,31 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s
 
 template <typename T>
 static __global__ void k_repeat_back(
-    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2) {
+    const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+    const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
 
-    const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
-    const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y;
-    const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z;
+    const int64_t tid0  = int64_t(blockIdx.x)*blockDim.x + threadIdx.x;
+    const int64_t tid1  = int64_t(blockIdx.y)*blockDim.y + threadIdx.y;
+    const int64_t tid23 = int64_t(blockIdx.z)*blockDim.z + threadIdx.z;
+    const int64_t tid2  = tid23 % ne2;
+    const int64_t tid3  = tid23 / ne2;
 
     if (tid0 >= ne0) {
         return;
     }
 
     T sum = 0;
-    for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
-        for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
-            for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
-                sum += src[i2*ne01*ne00 + i1*ne00 + i0];
+    for (int64_t i3 = tid3; i3 < ne03; i3 += ne3) {
+        for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) {
+            for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) {
+                for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) {
+                    sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00];
+                }
             }
         }
     }
-    dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
+    dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
 }
 
 template<float (*bin_op)(const float, const float)>
@@ -274,12 +279,14 @@ struct bin_bcast_cuda {
 
 template <typename T>
 static void repeat_back_cuda(
-    const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02,
-    const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) {
+        const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const size_t s00, const size_t s01, const size_t s02, const size_t s03,
+        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
     const dim3 block_dims(WARP_SIZE, 1, 1);
-    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2);
-    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2);
+    const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2*ne3);
+    k_repeat_back<T><<<block_nums, block_dims, 0, stream>>>
+        (src, dst, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3);
 }
 
 template<class op>
@@ -326,27 +333,26 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const ggml_tensor * src0 = dst->src[0];
 
     GGML_ASSERT(src0->type == dst->type);
-    GGML_ASSERT(ggml_is_contiguous(src0));
     GGML_ASSERT(ggml_is_contiguous(dst));
     GGML_ASSERT(ggml_can_repeat(dst, src0));
 
     cudaStream_t stream = ctx.stream();
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    GGML_ASSERT(src0->ne[3] == 1);
+    GGML_TENSOR_UNARY_OP_LOCALS;
+
+    GGML_ASSERT(ne2*ne3 <= (1 << 15));
 
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-    const int64_t ne2 = dst->ne[2];
-    GGML_ASSERT(dst->ne[3] == 1);
+    const size_t ts = ggml_type_size(src0->type);
+    const size_t s00 = nb00 / ts;
+    const size_t s01 = nb01 / ts;
+    const size_t s02 = nb02 / ts;
+    const size_t s03 = nb03 / ts;
 
     switch (dst->type) {
         case GGML_TYPE_F32: {
             const float * src0_d = (const float *) src0->data;
             float * dst_d = (float *) dst->data;
-            repeat_back_cuda<float>(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream);
+            repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s00, s01, s02, s03, ne0, ne1, ne2, ne3, stream);
         } break;
         default: {
            GGML_ASSERT(false);
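The reworked kernel assigns one thread per dst element (ne2 and ne3 are folded into gridDim.z and unraveled via tid23) and sums every src element that ggml_repeat would have replicated onto that position, now addressing src through explicit element strides s00..s03 so noncontiguous views work. A CPU reference of the same reduction, useful for checking the index math (a sketch, not part of the commit):

    template <typename T>
    static void repeat_back_ref(
            const T * src, T * dst,
            const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
            const size_t s00, const size_t s01, const size_t s02, const size_t s03,
            const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3) {
        for (int64_t t3 = 0; t3 < ne3; ++t3) {
        for (int64_t t2 = 0; t2 < ne2; ++t2) {
        for (int64_t t1 = 0; t1 < ne1; ++t1) {
        for (int64_t t0 = 0; t0 < ne0; ++t0) {
            T sum = 0;
            // gather every src position that maps onto dst[t0,t1,t2,t3] under ggml_repeat
            for (int64_t i3 = t3; i3 < ne03; i3 += ne3) {
            for (int64_t i2 = t2; i2 < ne02; i2 += ne2) {
            for (int64_t i1 = t1; i1 < ne01; i1 += ne1) {
            for (int64_t i0 = t0; i0 < ne00; i0 += ne0) {
                sum += src[i3*s03 + i2*s02 + i1*s01 + i0*s00]; // strided, so views are fine
            }}}}
            dst[((t3*ne2 + t2)*ne1 + t1)*ne0 + t0] = sum; // dst is contiguous
        }}}}
    }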

ggml/src/ggml-cuda/ggml-cuda.cu

+1 -1

@@ -3002,7 +3002,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16;
             } break;
         case GGML_OP_REPEAT_BACK:
-            return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1;
+            return op->type == GGML_TYPE_F32 && (op->src[0]->ne[2]*op->src[0]->ne[3]) <= (1 << 15);
         case GGML_OP_CONCAT:
             {
                 ggml_type src0_type = op->src[0]->type;
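The old ne[3] == 1 restriction is replaced by a cap on ne[2]*ne[3], which becomes gridDim.z in repeat_back_cuda. One reading of the constant (an inference, not stated in the commit): CUDA limits gridDim.z to 65535, so allowing at most 1 << 15 = 32768 combined batch dims always yields a legal launch:

    // sketch of the launch-size reasoning (inferred)
    static bool repeat_back_fits_grid(const int64_t ne2, const int64_t ne3) {
        return ne2*ne3 <= (1 << 15); // block_nums.z = ne2*ne3 must stay <= 65535
    }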

ggml/src/ggml-cuda/out-prod.cu

+5 -2

@@ -34,6 +34,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     CUBLAS_CHECK(cublasSetStream(handle, stream));
 
+    const int64_t lda = nb01 / sizeof(float);
+    const int64_t ldc = nb1  / sizeof(float);
+
     const bool src1_T = ggml_is_transposed(src1);
     const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T;
     const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float);
@@ -57,9 +60,9 @@ void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             CUBLAS_CHECK(
                 cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op,
                         ne0, ne1, ne01,
-                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, ne00,
+                        &alpha, src0_d + (i3/dps3)*s03 + (i2/dps2)*s02, lda,
                                 src1_d + i3*s13 + i2*s12, ldb,
-                        &beta,  dst_d + i3*s3 + i2*s2, ne0));
+                        &beta,  dst_d + i3*s3 + i2*s2, ldc));
         }
     }
 }
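cublasSgemm is column-major and takes an explicit leading dimension per matrix; passing ne00 and ne0 implicitly assumed contiguous rows, which breaks when src0 or dst is a strided view. Deriving the leading dimensions from the byte strides makes the same GEMM correct for such tensors. The mapping, as a sketch (float data, as in this function):

    // elements between consecutive columns, as cuBLAS sees each matrix
    const int64_t lda = nb01 / sizeof(float);                   // src0: ne00 rows, ne01 columns
    const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); // src1, possibly transposed
    const int64_t ldc = nb1  / sizeof(float);                   // dst:  ne0 rows, ne1 columns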

ggml/src/ggml.c

+19 -13

@@ -5339,7 +5339,7 @@ static void ggml_compute_backward(
             } break;
         case GGML_OP_MUL: {
             if (src0_needs_grads) {
-                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad));
+                ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, src1));
             }
             if (src1_needs_grads) {
                 struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad);
@@ -5431,21 +5431,25 @@ static void ggml_compute_backward(
             // src1.shape [n,p,qq,rr]
 
             if (src0_needs_grads) {
-                struct ggml_tensor * s1_tg =
+                GGML_ASSERT(grad->ne[2] == src1->ne[2]);
+                GGML_ASSERT(grad->ne[3] == src1->ne[3]);
+                struct ggml_tensor * tmp =
                     ggml_out_prod(ctx, // [n,m,qq,rr]
                         src1,          // [n,p,qq,rr]
                         grad);         // [m,p,qq,rr]
-                const int64_t qq = s1_tg->ne[2];
-                const int64_t rr = s1_tg->ne[3];
-                const int64_t q1 = src0->ne[2];
-                const int64_t r1 = src0->ne[3];
-                const bool ne2_broadcasted = qq > q1;
-                const bool ne3_broadcasted = rr > r1;
-                if (ne2_broadcasted || ne3_broadcasted) {
-                    // sum broadcast repetitions of s1_tg into shape of src0
-                    s1_tg = ggml_repeat_back(ctx, s1_tg, src0);
+                if (!ggml_are_same_shape(tmp, src0)) {
+                    GGML_ASSERT(tmp->ne[0] == src0->ne[0]);
+                    GGML_ASSERT(tmp->ne[1] == src0->ne[1]);
+                    GGML_ASSERT(tmp->ne[3] == 1);
+
+                    const int64_t nr2 = tmp->ne[2] / src0->ne[2];
+                    const size_t nb2 = tmp->nb[2] * nr2;
+                    const size_t nb3 = tmp->nb[2];
+
+                    tmp = ggml_view_4d(ctx, tmp, src0->ne[0], src0->ne[1], src0->ne[2], nr2, tmp->nb[1], nb2, nb3, 0);
+                    tmp = ggml_repeat_back(ctx, tmp, src0);
                 }
-                ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/);
+                ggml_add_or_set(ctx, cgraph, isrc0, tmp);
             }
             if (src1_needs_grads) {
                 ggml_add_or_set(ctx, cgraph, isrc1,
@@ -5514,7 +5518,9 @@ static void ggml_compute_backward(
             if (src0_needs_grads) {
                 GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0]));
                 GGML_ASSERT(ggml_is_contiguous(grad));
-                ggml_add_or_set(ctx, cgraph, isrc0, grad);
+                GGML_ASSERT(ggml_nelements(tensor) == ggml_nelements(src0));
+                ggml_add_or_set(ctx, cgraph, isrc0,
+                    ggml_are_same_shape(tensor, src0) ? grad : ggml_reshape(ctx, grad, src0));
             }
         } break;
         case GGML_OP_RESHAPE: {
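The GQA-relevant change is in the MUL_MAT backward pass: when src1 and grad carry more heads than src0 (tmp->ne[2] > src0->ne[2], with tmp->ne[3] == 1), the [n,m,qq,1] out_prod result is reinterpreted as a 4-D view that groups the nr2 = qq/q1 query heads sharing one KV head into dim 3, which ggml_repeat_back then sums away. The GGML_OP_MUL change swaps the operand order so the accumulated gradient takes grad's shape, since ggml_mul broadcasts its second argument into the first. Worked shapes for an illustrative GQA case (hypothetical numbers; ctx, src0, src1, grad assumed in scope as above):

    // src0 [n=128, m=64, q1=8, r1=1] (8 KV heads); src1 [128, p=32, qq=32, rr=1] (32 query heads)
    struct ggml_tensor * tmp = ggml_out_prod(ctx, src1, grad);     // -> [128, 64, 32, 1]
    const int64_t nr2 = tmp->ne[2] / src0->ne[2];                  // 32/8 = 4 repeats per KV head
    tmp = ggml_view_4d(ctx, tmp, 128, 64, 8, nr2,
                       tmp->nb[1], tmp->nb[2]*nr2, tmp->nb[2], 0); // group the 4 repeats into dim 3
    tmp = ggml_repeat_back(ctx, tmp, src0);                        // sum dim 3 -> [128, 64, 8, 1]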

tests/test-backend-ops.cpp

+98 -19

@@ -1302,6 +1302,59 @@ struct test_repeat : public test_case {
     }
 };
 
+// GGML_OP_REPEAT_BACK
+struct test_repeat_back : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const std::array<int, 4> nr;
+    const bool v; // whether src is a noncontiguous view
+
+    std::string vars() override {
+        return VARS_TO_STR4(type, ne, nr, v);
+    }
+
+    size_t op_size(ggml_tensor * t) override {
+        return ggml_nbytes(t) * 2;
+    }
+
+    test_repeat_back(ggml_type type = GGML_TYPE_F32,
+            std::array<int64_t, 4> ne = {8, 6, 4, 2},
+            std::array<int, 4> nr = {2, 2, 2, 2},
+            bool v = false)
+        : type(type), ne(ne), nr(nr), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * src = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]);
+        ggml_set_name(src, "src");
+
+        if (v) {
+            GGML_ASSERT(ne[0] % 2 == 0);
+            GGML_ASSERT(ne[1] % 2 == 0);
+            GGML_ASSERT(ne[2] % 2 == 0);
+            GGML_ASSERT(ne[3] % 2 == 0);
+            GGML_ASSERT(nr[0] % 2 == 0 || nr[0] == 1);
+            GGML_ASSERT(nr[1] % 2 == 0 || nr[1] == 1);
+            GGML_ASSERT(nr[2] % 2 == 0 || nr[2] == 1);
+            GGML_ASSERT(nr[3] % 2 == 0 || nr[3] == 1);
+
+            const int64_t ne00 = nr[0] == 1 ? src->ne[0] : src->ne[0] / 2;
+            const int64_t ne01 = nr[1] == 1 ? src->ne[1] : src->ne[1] / 2;
+            const int64_t ne02 = nr[2] == 1 ? src->ne[2] : src->ne[2] / 2;
+            const int64_t ne03 = nr[3] == 1 ? src->ne[3] : src->ne[3] / 2;
+
+            src = ggml_view_4d(ctx, src, ne00, ne01, ne02, ne03, src->nb[1], src->nb[2], src->nb[3], 0);
+        }
+
+        ggml_tensor * target = ggml_new_tensor(ctx, type, 4, ne.data());
+        ggml_set_name(target, "target");
+
+        ggml_tensor * out = ggml_repeat_back(ctx, src, target);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+};
+
 // GGML_OP_DUP
 struct test_dup : public test_case {
     const ggml_type type;
@@ -1849,6 +1902,10 @@ struct test_mul_mat : public test_case {
         return 5e-4;
     }
 
+    int64_t grad_nmax() override {
+        return 20000;
+    }
+
     uint64_t op_flops(ggml_tensor * t) override {
         GGML_UNUSED(t);
         return 2 * m * n * k * bs[0] * nr[0] * bs[1] * nr[1];
@@ -1878,8 +1935,12 @@ struct test_mul_mat : public test_case {
 
             a = ggml_new_tensor_4d(ctx, type_a, ne_a[per[0]], ne_a[per[1]], ne_a[per[2]], ne_a[per[3]]);
             b = ggml_new_tensor_4d(ctx, type_b, ne_b[per[0]], ne_b[per[1]], ne_b[per[2]], ne_b[per[3]]);
-            ggml_set_param(ctx, a);
-            ggml_set_param(ctx, b);
+            if (!ggml_is_quantized(type_a)) {
+                if (bs[1] == 1 && nr[1] == 1) {
+                    ggml_set_param(ctx, a);
+                }
+                ggml_set_param(ctx, b);
+            }
             ggml_set_name(a, "a");
             ggml_set_name(b, "b");
@@ -1890,8 +1951,12 @@ struct test_mul_mat : public test_case {
         } else {
             a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
             b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
-            ggml_set_param(ctx, a);
-            ggml_set_param(ctx, b);
+            if (!ggml_is_quantized(type_a)) {
+                if (bs[1] == 1 && nr[1] == 1) {
+                    ggml_set_param(ctx, a);
+                }
+                ggml_set_param(ctx, b);
+            }
             ggml_set_name(a, "a");
             ggml_set_name(b, "b");
         }
@@ -3798,6 +3863,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2}));
     }
 
+    for (bool view : {false, true}) {
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 2, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 2, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_F32, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I32, {8, 6, 4, 2}, {2, 1, 1, 1}, view));
+        test_cases.emplace_back(new test_repeat_back(GGML_TYPE_I16, {8, 6, 4, 2}, {1, 1, 1, 2}, view));
+    }
+
     test_cases.emplace_back(new test_dup(GGML_TYPE_F32));
     test_cases.emplace_back(new test_dup(GGML_TYPE_F16));
     test_cases.emplace_back(new test_dup(GGML_TYPE_I32));
@@ -3919,21 +3994,25 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
             // test cases without permutation
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, { 1,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10,  1}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10, 10}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10, 10}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10, 10}, {1, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {10, 10}, {2, 2}));
-
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10,  1}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10,  1}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2}));
-            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {1, 1}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {3, 2}, {2, 2}));
+
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {1, 1}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 1}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 1}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {1, 2}));
+            test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {3, 2}, {2, 2}));
 
             // test cases with permutation
             test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16,  1, 256, {2, 3}, {1, 1}, {0, 2, 1, 3}));
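The new eval cases cover each repeat axis of REPEAT_BACK separately, from both contiguous tensors and noncontiguous views, and the mul_mat grid now uses asymmetric batch counts such as {3, 1} and {3, 2} so that ne[2]/ne[3] broadcast mix-ups cannot cancel out. Quantized type_a is no longer marked as a parameter (no gradients through quantized weights), a only when the ne[3] batch dims are not broadcast, and the grad_nmax() override raises the element cap for gradient checks so the larger batched cases still get backward coverage. To exercise just these ops (invocation shown as an assumption about the usual test-backend-ops CLI):

    ./build/bin/test-backend-ops test -o REPEAT_BACK
    ./build/bin/test-backend-ops grad -o MUL_MAT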
