Skip to content

Commit bd541ae

Browse files
committed
metal : optimize concat kernel and fix set kernel threads (llama/23411)
* metal : fix GGML_OP_SET kernel threads

* tests : extend test_cpy to support different src/dst shapes

  Extend test_cpy to support different source and destination tensor shapes for CPY operations (reshaping), where the total number of elements must match.

  - Renamed ne -> ne_src, added ne_dst parameter (default: use src shape)
  - Added 50 new reshaping test cases covering 1D<->2D<->3D<->4D conversions
  - Tests exercise 1024 boundary, small shapes, and large dimensionality changes
  - Fixed dangling reference bug (storing & to temporary std::array)
  - Updated all existing test calls with permute/transpose args for compatibility

  Assisted-by: llama.cpp:local pi

* metal : optimize concat kernel with row batching for small widths

  When ne0 < 256, batch multiple rows into a single threadgroup to improve occupancy. This avoids underutilizing the GPU when processing narrow tensors.

  - Dispatch nth = min(256, ne0) threads per group
  - Calculate nrptg (rows per threadgroup) to fill up to 256 threads
  - Update kernel index calculation to handle the row batching
  - Add boundary check for i1 >= ne1

  Assisted-by: llama.cpp:local pi

* tests : clean-up

* tests : refactor CPY shape tests to use dimension permutations

  Replace 75 hardcoded test cases with a loop over permutations of {3, 5, 7, 32} (total elements: 3360). Each src permutation is tested against canonical sorted and reverse dst, skipping identical shapes. Covers F32, F16, and Q4_0 (when both src and dst ne0 == 32).

  Assisted-by: llama.cpp:local pi
1 parent a48e8e5 commit bd541ae

3 files changed

Lines changed: 113 additions & 42 deletions

File tree

src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -564,9 +564,20 @@ int ggml_metal_op_concat(ggml_metal_op_t ctx, int idx) {
564564
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2);
565565
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3);
566566

567-
const int nth = std::min(1024, ne0);
567+
int nth = std::min(256, ne0);
568568

569-
ggml_metal_encoder_dispatch_threadgroups(enc, ne1, ne2, ne3, nth, 1, 1);
569+
// when rows are small, we can batch them together in a single threadgroup
570+
int nrptg = 1;
571+
if (nth < 256) {
572+
nrptg = std::min((256 + nth - 1) / nth, ne1);
573+
if (nrptg * nth > 256) {
574+
nrptg = 256 / nth;
575+
}
576+
}
577+
578+
const int nw0 = (ne1 + nrptg - 1) / nrptg;
579+
580+
ggml_metal_encoder_dispatch_threadgroups(enc, nw0, ne2, ne3, nth, nrptg, 1);
570581

571582
return 1;
572583
}
@@ -1786,7 +1797,7 @@ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
17861797
nk0 = ne10/ggml_blck_size(op->type);
17871798
}
17881799

1789-
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
1800+
int nth = std::min<int>(nk0*ne11, 256);
17901801

17911802
// when rows are small, we can batch them together in a single threadgroup
17921803
int nrptg = 1;
@@ -1797,7 +1808,7 @@ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
17971808
nrptg = (nth + nk0 - 1)/nk0;
17981809
nth = nk0;
17991810

1800-
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
1811+
if (nrptg*nth > 256) {
18011812
nrptg--;
18021813
}
18031814
}

src/ggml-metal/ggml-metal.metal

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7486,7 +7486,11 @@ kernel void kernel_concat(
74867486

74877487
const int i3 = tgpig.z;
74887488
const int i2 = tgpig.y;
7489-
const int i1 = tgpig.x;
7489+
const int i1 = ntg.y == 1 ? tgpig.x : tgpig.x*ntg.y + tpitg.y;
7490+
7491+
if (i1 >= args.ne1) {
7492+
return;
7493+
}
74907494

74917495
int o[4] = {0, 0, 0, 0};
74927496
o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));

tests/test-backend-ops.cpp

Lines changed: 93 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2866,15 +2866,24 @@ struct test_set : public test_case {
28662866
struct test_cpy : public test_case {
28672867
const ggml_type type_src;
28682868
const ggml_type type_dst;
2869-
const std::array<int64_t, 4> ne;
2869+
const std::array<int64_t, 4> ne_src;
2870+
const std::array<int64_t, 4> ne_dst;
28702871
const std::array<int64_t, 4> permute_src;
28712872
const std::array<int64_t, 4> permute_dst;
28722873
bool _src_use_permute;
28732874
bool _dst_use_permute;
28742875
bool _src_transpose;
2876+
bool _use_dst_shape;
28752877

28762878
std::string vars() override {
2877-
return VARS_TO_STR6(type_src, type_dst, ne, permute_src, permute_dst, _src_transpose);
2879+
if (_use_dst_shape) {
2880+
return VARS_TO_STR7(type_src, type_dst, ne_src, ne_dst, permute_src, permute_dst, _src_transpose);
2881+
}
2882+
return VARS_TO_STR6(type_src, type_dst, ne_src, permute_src, permute_dst, _src_transpose);
2883+
}
2884+
2885+
int64_t total_elements() const {
2886+
return ne_src[0] * ne_src[1] * ne_src[2] * ne_src[3];
28782887
}
28792888

28802889
double max_nmse_err() override {
@@ -2899,7 +2908,7 @@ struct test_cpy : public test_case {
28992908
err_estimate /= 8.0f;
29002909
}
29012910
err_estimate *= err_estimate;
2902-
err_estimate /= (150.0f*150.0f*0.25f)*float(ne[0] * ne[1] * ne[2] * ne[3]);
2911+
err_estimate /= (150.0f*150.0f*0.25f)*float(total_elements());
29032912
return err_estimate;
29042913
}
29052914
return 1e-6;
@@ -2910,17 +2919,19 @@ struct test_cpy : public test_case {
29102919
}
29112920

29122921
test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32,
2913-
std::array<int64_t, 4> ne = {10, 10, 10, 1},
2922+
std::array<int64_t, 4> ne_src = {10, 10, 10, 1},
2923+
std::array<int64_t, 4> ne_dst = {-1, -1, -1, -1},
29142924
std::array<int64_t, 4> permute_src = {0, 0, 0, 0},
29152925
std::array<int64_t, 4> permute_dst = {0, 0, 0, 0},
29162926
bool transpose_src = false)
2917-
: type_src(type_src), type_dst(type_dst), ne(ne), permute_src(permute_src), permute_dst(permute_dst),
2927+
: type_src(type_src), type_dst(type_dst), ne_src(ne_src), ne_dst(ne_dst), permute_src(permute_src), permute_dst(permute_dst),
29182928
_src_use_permute(permute_src[0] + permute_src[1] + permute_src[2] + permute_src[3] > 0),
29192929
_dst_use_permute(permute_dst[0] + permute_dst[1] + permute_dst[2] + permute_dst[3] > 0),
2920-
_src_transpose(transpose_src){}
2930+
_src_transpose(transpose_src),
2931+
_use_dst_shape(ne_dst[0] >= 0 && ne_dst[1] >= 0 && ne_dst[2] >= 0 && ne_dst[3] >= 0){}
29212932

29222933
ggml_tensor * build_graph(ggml_context * ctx) override {
2923-
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data());
2934+
ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne_src.data());
29242935
ggml_set_param(src);
29252936
ggml_set_name(src, "src");
29262937

@@ -2934,7 +2945,8 @@ struct test_cpy : public test_case {
29342945
ggml_set_name(src, "src_transposed");
29352946
}
29362947

2937-
ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, src->ne);
2948+
std::array<int64_t, 4> dst_ne = _use_dst_shape ? ne_dst : std::array<int64_t, 4>{src->ne[0], src->ne[1], src->ne[2], src->ne[3]};
2949+
ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, dst_ne.data());
29382950
ggml_set_name(dst, "dst");
29392951

29402952
if (_dst_use_permute) {
@@ -8040,42 +8052,72 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
80408052

80418053
for (int k = 1; k < 4; ++k) {
80428054
test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}));
8043-
test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 2, 1, 3}));
8044-
test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {0, 3, 1, 2}, {0, 2, 1, 3}));
8055+
test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {-1,-1,-1,-1}, {0, 2, 1, 3}));
8056+
test_cases.emplace_back(new test_cpy(type, type, {k*nk, 2, 3, 4}, {-1,-1,-1,-1}, {0, 3, 1, 2}, {0, 2, 1, 3}));
80458057
}
80468058
}
80478059

80488060
for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_F32}) {
80498061
for (ggml_type type_dst : all_types) {
80508062
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
8051-
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
8063+
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {-1,-1,-1,-1}, {0, 2, 1, 3})); // cpy by rows
80528064
}
80538065
}
80548066
for (ggml_type type_src : all_types) {
80558067
for (ggml_type type_dst : {GGML_TYPE_F32}) {
80568068
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));
8057-
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {0, 2, 1, 3})); // cpy by rows
8069+
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {-1,-1,-1,-1}, {0, 2, 1, 3})); // cpy by rows
80588070
}
80598071
}
80608072
for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
80618073
for (ggml_type type_dst : {GGML_TYPE_F16, GGML_TYPE_F32}) {
8062-
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {1, 0, 2, 3})); // cpy not-contiguous
8074+
test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 2, 3, 4}, {-1,-1,-1,-1}, {1, 0, 2, 3})); // cpy not-contiguous
80638075
}
80648076
}
80658077
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}));
8066-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {1, 0, 2, 3}));
8078+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_I32, {256, 2, 3, 4}, {-1,-1,-1,-1}, {1, 0, 2, 3}));
80678079
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}));
8068-
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {1, 0, 2, 3}));
8069-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8070-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8071-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 3}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8072-
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 3, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8073-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8074-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8075-
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8076-
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 4, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8077-
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
8078-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
8080+
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_F32, {256, 2, 3, 4}, {-1,-1,-1,-1}, {1, 0, 2, 3}));
8081+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 3, 1}, {-1,-1,-1,-1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8082+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 1}, {-1,-1,-1,-1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8083+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 3, 3}, {-1,-1,-1,-1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8084+
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 3, 1}, {-1,-1,-1,-1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8085+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {256, 4, 1, 1}, {-1,-1,-1,-1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8086+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 4, 1, 1}, {-1,-1,-1,-1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8087+
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {256, 4, 1, 1}, {-1,-1,-1,-1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8088+
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 4, 1, 1}, {-1,-1,-1,-1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
8089+
test_cases.emplace_back(new test_cpy(GGML_TYPE_I32, GGML_TYPE_I32, {256, 1, 4, 1}, {-1,-1,-1,-1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
8090+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {256, 1, 4, 1}, {-1,-1,-1,-1}, {1, 2, 0, 3}, {0, 0, 0, 0}));
8091+
8092+
// CPY - different src/dst shapes (reshaping via CPY)
8093+
// Use permutations of {3, 5, 7, 32}. Total elements: 3*5*7*32 = 3360.
8094+
// Each src permutation is tested against canonical sorted and reverse dst (skip self).
8095+
{
8096+
std::array<int64_t, 4> dims = {3, 5, 7, 32};
8097+
std::sort(dims.begin(), dims.end());
8098+
std::array<int64_t, 4> canonical = dims;
8099+
std::array<int64_t, 4> reversed = {32, 7, 5, 3};
8100+
for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
8101+
std::array<int64_t, 4> cur = dims;
8102+
do {
8103+
if (cur != canonical) {
8104+
test_cases.emplace_back(new test_cpy(type, type, cur, canonical));
8105+
}
8106+
if (cur != reversed) {
8107+
test_cases.emplace_back(new test_cpy(type, type, cur, reversed));
8108+
}
8109+
if (cur[0] == 32 && type == GGML_TYPE_F32) {
8110+
if (canonical[0] == 32) {
8111+
test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0, cur, canonical));
8112+
}
8113+
if (reversed[0] == 32) {
8114+
test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_Q4_0, cur, reversed));
8115+
}
8116+
}
8117+
std::next_permutation(cur.begin(), cur.end());
8118+
} while (cur != canonical);
8119+
}
8120+
}
80798121

80808122
for (ggml_type type_dst : { GGML_TYPE_F32, GGML_TYPE_I32, GGML_TYPE_F16, GGML_TYPE_BF16 }) {
80818123
for (bool use_view_slice : { true, false }) {
@@ -8830,9 +8872,24 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
88308872
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, 1));
88318873
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2));
88328874
test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3));
8875+
88338876
test_cases.emplace_back(new test_pad());
88348877
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {33, 17, 2, 1}, 4, 3, true)); // circular
88358878
test_cases.emplace_back(new test_pad_ext());
8879+
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {1024, 1, 1, 1}, 1, 0, false));
8880+
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {1024, 2, 1, 1}, 1, 0, false));
8881+
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {1024, 16, 1, 1}, 0, 1, false));
8882+
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {1023, 1, 1, 1}, 1, 0, false));
8883+
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {1023, 8, 1, 1}, 1, 0, false));
8884+
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {1025, 1, 1, 1}, 1, 0, false));
8885+
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {1025, 8, 1, 1}, 1, 0, false));
8886+
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {2048, 1, 1, 1}, 1, 0, false));
8887+
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {2048, 4, 1, 1}, 1, 0, false));
8888+
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {2049, 1, 1, 1}, 1, 0, false));
8889+
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {100, 1, 1, 1}, 100, 0, false));
8890+
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {100, 1, 1, 1}, 0, 100, false));
8891+
test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {100, 100, 1, 1}, 50, 50, false));
8892+
88368893
test_cases.emplace_back(new test_pad_reflect_1d());
88378894
test_cases.emplace_back(new test_pad_reflect_1d(GGML_TYPE_F32, {3000, 384, 4, 1}));
88388895
test_cases.emplace_back(new test_roll());
@@ -9132,22 +9189,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
91329189
test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1}));
91339190

91349191
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F16, {512, 3072, 1, 1}));
9135-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
9136-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
9192+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {-1,-1,-1,-1}, {0, 2, 1, 3}));
9193+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {-1,-1,-1,-1}, {0, 2, 1, 3}));
91379194
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_Q4_0, {8192, 512, 2, 1}));
91389195
test_cases.emplace_back(new test_cpy(GGML_TYPE_Q4_0, GGML_TYPE_F32, {8192, 512, 2, 1}));
91399196

9140-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
9141-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
9142-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
9143-
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
9144-
9145-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
9146-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
9147-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
9148-
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
9149-
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
9197+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {-1,-1,-1,-1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
9198+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {-1,-1,-1,-1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
9199+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {-1,-1,-1,-1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
9200+
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {-1,-1,-1,-1}, {1, 0, 2, 3}, {0, 0, 0, 0}));
91509201

9202+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768*1024, 256, 1, 1}, {-1,-1,-1,-1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
9203+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {768, 1024, 256, 1}, {-1,-1,-1,-1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
9204+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768*1024, 256, 1, 1}, {-1,-1,-1,-1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
9205+
test_cases.emplace_back(new test_cpy(GGML_TYPE_F16, GGML_TYPE_F16, {768, 1024, 256, 1}, {-1,-1,-1,-1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
9206+
test_cases.emplace_back(new test_cpy(GGML_TYPE_BF16, GGML_TYPE_BF16, {768, 1024, 256, 1}, {-1,-1,-1,-1}, {0, 0, 0, 0}, {0, 0, 0, 0}, true));
91519207

91529208
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));
91539209
test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {12888, 256, 5, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f));

0 commit comments

Comments
 (0)