@@ -162,6 +162,7 @@ struct vk_device_struct {
162
162
uint32_t subgroup_size;
163
163
uint32_t shader_core_count;
164
164
bool uma;
165
+ bool float_controls_rte_fp16;
165
166
bool coopmat2;
166
167
167
168
bool coopmat_support;
@@ -1916,17 +1917,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
1916
1917
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 3, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
1917
1918
1918
1919
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1919
- ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1920
-
1921
1920
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1922
- ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1921
+
1922
+ if (device->float_controls_rte_fp16) {
1923
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1924
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1925
+ } else {
1926
+ ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1927
+ ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
1928
+ }
1923
1929
1924
1930
ggml_vk_create_pipeline(device, device->pipeline_argsort_f32, "argsort_f32", argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1024, 1, 1}, {}, 1);
1925
1931
1926
1932
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
1927
1933
1928
1934
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
1929
- ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
1935
+ if (device->float_controls_rte_fp16) {
1936
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
1937
+ } else {
1938
+ ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {256, 1, 1}, {}, 1);
1939
+ }
1930
1940
1931
1941
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
1932
1942
@@ -2007,11 +2017,13 @@ static vk_device ggml_vk_get_device(size_t idx) {
2007
2017
vk::PhysicalDeviceDriverProperties driver_props;
2008
2018
vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props;
2009
2019
vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props;
2020
+ vk::PhysicalDeviceVulkan12Properties vk12_props;
2010
2021
props2.pNext = &props3;
2011
2022
props3.pNext = &subgroup_props;
2012
2023
subgroup_props.pNext = &driver_props;
2024
+ driver_props.pNext = &vk12_props;
2013
2025
2014
- VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props ;
2026
+ VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props ;
2015
2027
2016
2028
if (maintenance4_support) {
2017
2029
last_struct->pNext = (VkBaseOutStructure *)&props4;
@@ -2057,6 +2069,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
2057
2069
} else {
2058
2070
device->shader_core_count = 0;
2059
2071
}
2072
+ device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16;
2060
2073
2061
2074
const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr;
2062
2075
0 commit comments