@@ -989,26 +989,27 @@ void ggml_vk_mul_mat_mat_f32(kp::Sequence& seq,
989
989
nb1, nb2
990
990
};
991
991
992
+ const uint32_t local_x = ggml_vk_current_device ().subgroupSize ;
992
993
std::shared_ptr<kp::Algorithm> s_algo = nullptr ;
993
994
if (!komputeManager ()->hasAlgorithm (__func__)) {
994
- // std::cerr << "init f32 matmat shader" << std::endl;
995
- s_algo = komputeManager ()->algorithm <float , PushConstants>(__func__, s_kompute_context->pool .get (),
995
+ s_algo = komputeManager ()->algorithm <uint32_t , PushConstants>(__func__, s_kompute_context->pool .get (),
996
996
{inA, inB, out}, spirv,
997
997
{unsigned (ne01),
998
998
unsigned (ne11),
999
- unsigned (ne12)},
1000
- {},
999
+ unsigned (std::max (ne12, ne02))
1000
+ },
1001
+ {local_x},
1001
1002
{pushConsts});
1002
1003
} else {
1003
1004
s_algo = komputeManager ()->getAlgorithm (__func__);
1004
1005
s_algo->setTensors ({inA, inB, out});
1005
1006
s_algo->setWorkgroup ({unsigned (ne01),
1006
1007
unsigned (ne11),
1007
- unsigned (std::max (ne12, ne02))});
1008
+ unsigned (std::max (ne12, ne02)),
1009
+ });
1008
1010
s_algo->setPushConstants <PushConstants>({pushConsts});
1009
1011
s_algo->updateDescriptors (s_kompute_context->pool .get ());
1010
1012
}
1011
- // seq.record<kp::OpTensorFill>({out});
1012
1013
seq.record <kp::OpAlgoDispatch>(s_algo);
1013
1014
}
1014
1015
@@ -1038,15 +1039,16 @@ void ggml_vk_mul_mat_mat_f16(kp::Sequence& seq,
1038
1039
nb1, nb2
1039
1040
};
1040
1041
1042
+ const uint32_t local_x = ggml_vk_current_device ().subgroupSize ;
1041
1043
std::shared_ptr<kp::Algorithm> s_algo = nullptr ;
1042
1044
if (!komputeManager ()->hasAlgorithm (__func__)) {
1043
- s_algo = komputeManager ()->algorithm <float , PushConstants>(__func__, s_kompute_context->pool .get (),
1045
+ s_algo = komputeManager ()->algorithm <uint32_t , PushConstants>(__func__, s_kompute_context->pool .get (),
1044
1046
{inA, inB, out}, spirv,
1045
1047
{unsigned (ne01),
1046
1048
unsigned (ne11),
1047
1049
unsigned (std::max (ne12, ne02))
1048
1050
},
1049
- {},
1051
+ {local_x },
1050
1052
{pushConsts});
1051
1053
} else {
1052
1054
s_algo = komputeManager ()->getAlgorithm (__func__);
@@ -1141,7 +1143,7 @@ void ggml_vk_mul_mat_mat_q6_k(
1141
1143
if (!komputeManager ()->hasAlgorithm (__func__)) {
1142
1144
s_algo = komputeManager ()->algorithm <float , PushConstants>(__func__, s_kompute_context->pool .get (),
1143
1145
{inA, inB, out}, spirv,
1144
- {unsigned (ne01)/32 ,
1146
+ {unsigned (ne01)/256 ,
1145
1147
unsigned (ne11),
1146
1148
unsigned (std::max (ne12, ne02))
1147
1149
},
@@ -1150,7 +1152,7 @@ void ggml_vk_mul_mat_mat_q6_k(
1150
1152
} else {
1151
1153
s_algo = komputeManager ()->getAlgorithm (__func__);
1152
1154
s_algo->setTensors ({inA, inB, out});
1153
- s_algo->setWorkgroup ({unsigned (ne01)/32 ,
1155
+ s_algo->setWorkgroup ({unsigned (ne01)/256 ,
1154
1156
unsigned (ne11),
1155
1157
unsigned (std::max (ne12, ne02)),
1156
1158
});
@@ -1192,7 +1194,7 @@ void ggml_vk_mul_mat_mat_q4_x(const std::vector<uint32_t>& spirv,
1192
1194
{unsigned (ne01),
1193
1195
unsigned (ne11),
1194
1196
unsigned (std::max (ne12, ne02))},
1195
- {local_x, 4 },
1197
+ {local_x, 1 },
1196
1198
{pushConsts});
1197
1199
} else {
1198
1200
s_algo = komputeManager ()->getAlgorithm (__func__);
0 commit comments