File tree 1 file changed +9
-6
lines changed
1 file changed +9
-6
lines changed Original file line number Diff line number Diff line change @@ -173,12 +173,15 @@ struct inner_product_forward : public dnnl::inner_product_forward {
173
173
IDEEP_ENFORCE (utils::one_of (weights.get_data_type (),
174
174
data_type::f32, data_type::bf16),
175
175
" Incorrect data type in weights" );
176
-
177
- // align weights data type with src
178
- dst_data_type = src.get_data_type () == data_type::bf16 ? data_type::bf16
179
- : data_type::f32;
180
- src_desc = src.get_desc ().to_type (dst_data_type);
181
- weights_desc = weights.get_desc ().to_type (dst_data_type);
176
+ if (dst.is_empty ()) {
177
+ // align weights data type with src
178
+ dst_data_type = src.get_data_type () == data_type::bf16 ? data_type::bf16
179
+ : data_type::f32;
180
+ } else {
181
+ dst_data_type = dst.get_data_type ();
182
+ }
183
+ src_desc = src.get_desc ().to_type (src.get_data_type ());
184
+ weights_desc = weights.get_desc ().to_type (src.get_data_type ());
182
185
if (with_bias) {
183
186
IDEEP_ENFORCE (utils::one_of (bias.get_data_type (),
184
187
data_type::f32, data_type::bf16),
You can’t perform that action at this time.
0 commit comments