Skip to content

Commit c922d9e

Browse files
liangan1EikanWang
authored andcommitted
Fix inner_product dst datatype issue (#365)
1 parent 9f0a0b4 commit c922d9e

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

ideep/ideep/operators/inner_product.hpp

+9-6
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,15 @@ struct inner_product_forward : public dnnl::inner_product_forward {
173173
IDEEP_ENFORCE(utils::one_of(weights.get_data_type(),
174174
data_type::f32, data_type::bf16),
175175
"Incorrect data type in weights");
176-
177-
// align weights data type with src
178-
dst_data_type = src.get_data_type() == data_type::bf16 ? data_type::bf16
179-
: data_type::f32;
180-
src_desc = src.get_desc().to_type(dst_data_type);
181-
weights_desc = weights.get_desc().to_type(dst_data_type);
176+
if (dst.is_empty()) {
177+
// align weights data type with src
178+
dst_data_type = src.get_data_type() == data_type::bf16 ? data_type::bf16
179+
: data_type::f32;
180+
} else {
181+
dst_data_type = dst.get_data_type();
182+
}
183+
src_desc = src.get_desc().to_type(src.get_data_type());
184+
weights_desc = weights.get_desc().to_type(src.get_data_type());
182185
if (with_bias) {
183186
IDEEP_ENFORCE(utils::one_of(bias.get_data_type(),
184187
data_type::f32, data_type::bf16),

0 commit comments

Comments
 (0)