@@ -1302,6 +1302,59 @@ struct test_repeat : public test_case {
1302
1302
}
1303
1303
};
1304
1304
1305
+ // GGML_OP_REPEAT_BACK: test case that builds a graph around ggml_repeat_back(src, target)
1306
+ struct test_repeat_back : public test_case {
1307
+ const ggml_type type; // element type used for both src and target
1308
+ const std::array<int64_t , 4 > ne; // shape of target; src is allocated with ne[i]*nr[i] per dimension
1309
+ const std::array<int , 4 > nr; // per-dimension multiplier applied to ne when allocating src
1310
+ const bool v; // whether src is a noncontiguous view
1311
+
1312
+ std::string vars () override { // stringifies the four test parameters via VARS_TO_STR4
1313
+ return VARS_TO_STR4 (type, ne, nr, v);
1314
+ }
1315
+
1316
+ size_t op_size (ggml_tensor * t) override { // reports twice the byte size of t
1317
+ return ggml_nbytes (t) * 2 ;
1318
+ }
1319
+
1320
+ test_repeat_back (ggml_type type = GGML_TYPE_F32, // defaults: F32, ne = {8,6,4,2}, nr = {2,2,2,2}, no view
1321
+ std::array<int64_t , 4 > ne = {8 , 6 , 4 , 2 },
1322
+ std::array<int , 4 > nr = {2 , 2 , 2 , 2 },
1323
+ bool v = false )
1324
+ : type(type), ne(ne), nr(nr), v(v) {}
1325
+
1326
+ ggml_tensor * build_graph (ggml_context * ctx) override { // creates src and target, returns the repeat_back result tensor
1327
+ ggml_tensor * src = ggml_new_tensor_4d (ctx, type, ne[0 ]*nr[0 ], ne[1 ]*nr[1 ], ne[2 ]*nr[2 ], ne[3 ]*nr[3 ]);
1328
+ ggml_set_name (src, " src" );
1329
+
1330
+ if (v) { // view mode: every ne[i] must be even and every nr[i] must be even or exactly 1
1331
+ GGML_ASSERT (ne[0 ] % 2 == 0 );
1332
+ GGML_ASSERT (ne[1 ] % 2 == 0 );
1333
+ GGML_ASSERT (ne[2 ] % 2 == 0 );
1334
+ GGML_ASSERT (ne[3 ] % 2 == 0 );
1335
+ GGML_ASSERT (nr[0 ] % 2 == 0 || nr[0 ] == 1 );
1336
+ GGML_ASSERT (nr[1 ] % 2 == 0 || nr[1 ] == 1 );
1337
+ GGML_ASSERT (nr[2 ] % 2 == 0 || nr[2 ] == 1 );
1338
+ GGML_ASSERT (nr[3 ] % 2 == 0 || nr[3 ] == 1 );
1339
+
1340
+ const int64_t ne00 = nr[0 ] == 1 ? src->ne [0 ] : src->ne [0 ] / 2 ; // view extent: full dimension when nr[i] == 1, otherwise half of src's dimension
1341
+ const int64_t ne01 = nr[1 ] == 1 ? src->ne [1 ] : src->ne [1 ] / 2 ;
1342
+ const int64_t ne02 = nr[2 ] == 1 ? src->ne [2 ] : src->ne [2 ] / 2 ;
1343
+ const int64_t ne03 = nr[3 ] == 1 ? src->ne [3 ] : src->ne [3 ] / 2 ;
1344
+
1345
+ src = ggml_view_4d (ctx, src, ne00, ne01, ne02, ne03, src->nb [1 ], src->nb [2 ], src->nb [3 ], 0 ); // view at offset 0 that reuses the parent tensor's strides nb[1..3]
1346
+ }
1347
+
1348
+ ggml_tensor * target = ggml_new_tensor (ctx, type, 4 , ne.data ()); // 4-dimensional tensor of shape ne
1349
+ ggml_set_name (target, " target" );
1350
+
1351
+ ggml_tensor * out = ggml_repeat_back (ctx, src, target); // the op under test
1352
+ ggml_set_name (out, " out" );
1353
+
1354
+ return out;
1355
+ }
1356
+ };
1357
+
1305
1358
// GGML_OP_DUP
1306
1359
struct test_dup : public test_case {
1307
1360
const ggml_type type;
@@ -1849,6 +1902,10 @@ struct test_mul_mat : public test_case {
1849
1902
return 5e-4 ;
1850
1903
}
1851
1904
1905
+ int64_t grad_nmax () override { // NOTE(review): presumably an upper bound on elements used for gradient checking — confirm against test_case
1906
+ return 20000 ;
1907
+ }
1908
+
1852
1909
uint64_t op_flops (ggml_tensor * t) override {
1853
1910
GGML_UNUSED (t);
1854
1911
return 2 * m * n * k * bs[0 ] * nr[0 ] * bs[1 ] * nr[1 ];
@@ -1878,8 +1935,12 @@ struct test_mul_mat : public test_case {
1878
1935
1879
1936
a = ggml_new_tensor_4d (ctx, type_a, ne_a[per[0 ]], ne_a[per[1 ]], ne_a[per[2 ]], ne_a[per[3 ]]);
1880
1937
b = ggml_new_tensor_4d (ctx, type_b, ne_b[per[0 ]], ne_b[per[1 ]], ne_b[per[2 ]], ne_b[per[3 ]]);
1881
- ggml_set_param (ctx, a);
1882
- ggml_set_param (ctx, b);
1938
+ if (!ggml_is_quantized (type_a)) {
1939
+ if (bs[1 ] == 1 && nr[1 ] == 1 ) {
1940
+ ggml_set_param (ctx, a);
1941
+ }
1942
+ ggml_set_param (ctx, b);
1943
+ }
1883
1944
ggml_set_name (a, " a" );
1884
1945
ggml_set_name (b, " b" );
1885
1946
@@ -1890,8 +1951,12 @@ struct test_mul_mat : public test_case {
1890
1951
} else {
1891
1952
a = ggml_new_tensor_4d (ctx, type_a, k, m, bs[0 ], bs[1 ]);
1892
1953
b = ggml_new_tensor_4d (ctx, type_b, k, n, bs[0 ]*nr[0 ], bs[1 ]*nr[1 ]);
1893
- ggml_set_param (ctx, a);
1894
- ggml_set_param (ctx, b);
1954
+ if (!ggml_is_quantized (type_a)) {
1955
+ if (bs[1 ] == 1 && nr[1 ] == 1 ) {
1956
+ ggml_set_param (ctx, a);
1957
+ }
1958
+ ggml_set_param (ctx, b);
1959
+ }
1895
1960
ggml_set_name (a, " a" );
1896
1961
ggml_set_name (b, " b" );
1897
1962
}
@@ -3798,6 +3863,16 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
3798
3863
test_cases.emplace_back (new test_repeat (GGML_TYPE_I16, {10 , 5 , 4 , ne3}, {1 , 1 , 1 , 2 }));
3799
3864
}
3800
3865
3866
+ for (bool view : {false , true }) {
3867
+ test_cases.emplace_back (new test_repeat_back (GGML_TYPE_F32, {8 , 6 , 4 , 2 }, {1 , 1 , 1 , 1 }, view));
3868
+ test_cases.emplace_back (new test_repeat_back (GGML_TYPE_F32, {8 , 6 , 4 , 2 }, {2 , 1 , 1 , 1 }, view));
3869
+ test_cases.emplace_back (new test_repeat_back (GGML_TYPE_F32, {8 , 6 , 4 , 2 }, {1 , 2 , 1 , 1 }, view));
3870
+ test_cases.emplace_back (new test_repeat_back (GGML_TYPE_F32, {8 , 6 , 4 , 2 }, {1 , 1 , 2 , 1 }, view));
3871
+ test_cases.emplace_back (new test_repeat_back (GGML_TYPE_F32, {8 , 6 , 4 , 2 }, {1 , 1 , 1 , 2 }, view));
3872
+ test_cases.emplace_back (new test_repeat_back (GGML_TYPE_I32, {8 , 6 , 4 , 2 }, {2 , 1 , 1 , 1 }, view));
3873
+ test_cases.emplace_back (new test_repeat_back (GGML_TYPE_I16, {8 , 6 , 4 , 2 }, {1 , 1 , 1 , 2 }, view));
3874
+ }
3875
+
3801
3876
test_cases.emplace_back (new test_dup (GGML_TYPE_F32));
3802
3877
test_cases.emplace_back (new test_dup (GGML_TYPE_F16));
3803
3878
test_cases.emplace_back (new test_dup (GGML_TYPE_I32));
@@ -3919,21 +3994,25 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
3919
3994
for (ggml_type type_a : base_types) {
3920
3995
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
3921
3996
// test cases without permutation
3922
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , { 1 , 1 }, {1 , 1 }));
3923
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 1 }, {1 , 1 }));
3924
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 1 }, {2 , 1 }));
3925
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {1 , 1 }));
3926
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {2 , 1 }));
3927
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {1 , 2 }));
3928
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {10 , 10 }, {2 , 2 }));
3929
-
3930
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , { 1 , 1 }, {1 , 1 }));
3931
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 1 }, {1 , 1 }));
3932
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 1 }, {2 , 1 }));
3933
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {1 , 1 }));
3934
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {2 , 1 }));
3935
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {1 , 2 }));
3936
- test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {10 , 10 }, {2 , 2 }));
3997
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {1 , 1 }, {1 , 1 }));
3998
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {1 , 1 }, {2 , 1 }));
3999
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {1 , 1 }, {1 , 2 }));
4000
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {3 , 1 }, {1 , 1 }));
4001
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {3 , 1 }, {2 , 1 }));
4002
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {3 , 2 }, {1 , 1 }));
4003
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {3 , 2 }, {2 , 1 }));
4004
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {3 , 2 }, {1 , 2 }));
4005
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {3 , 2 }, {2 , 2 }));
4006
+
4007
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {1 , 1 }, {1 , 1 }));
4008
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {1 , 1 }, {2 , 1 }));
4009
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {1 , 1 }, {1 , 2 }));
4010
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 1 }, {1 , 1 }));
4011
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 1 }, {2 , 1 }));
4012
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 2 }, {1 , 1 }));
4013
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 2 }, {2 , 1 }));
4014
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 2 }, {1 , 2 }));
4015
+ test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 16 , 256 , {3 , 2 }, {2 , 2 }));
3937
4016
3938
4017
// test cases with permutation
3939
4018
test_cases.emplace_back (new test_mul_mat (type_a, type_b, 16 , 1 , 256 , {2 , 3 }, {1 , 1 }, {0 , 2 , 1 , 3 }));
0 commit comments