Skip to content

Commit beb950e

Browse files
yufenglee
authored and
Du Li
committed
Fuse MatMulIntegerToFloat only when scales are scalar (microsoft#6008)
MatMulIntegerToFloat fusion fuses per-row and per-column MatMulInteger, which is not supported by the MatMulIntegerToFloat kernel yet. Limit the fusion to per-matrix (scalar scales) only, until per-channel is fully supported.
1 parent 3e39b38 commit beb950e

File tree

6 files changed

+21
-6
lines changed

6 files changed

+21
-6
lines changed

onnxruntime/core/optimizer/matmul_integer_to_float.cc

+8-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ static bool CheckBiasShape(const TensorShapeProto* bias_shape) {
3434
/**
3535
MatMulIntegerToFloatFusion will fuse subgraph like below into MatMulIntegerToFloat:
3636
37-
A A_Zero B B_Zero A_Scale) B_Scale Bias (Const, Optional)
37+
A A_Zero B B_Zero A_Scale B_Scale Bias (Const, Optional)
3838
\ | | / \ / |
3939
\ | | / \ / |
4040
\ | | / \ / |
@@ -84,6 +84,13 @@ Status MatMulIntegerToFloatFusion::ApplyImpl(Graph& graph, bool& modified, int g
8484
continue;
8585
}
8686

87+
// Both A_Scale and B_Scale must be scalars (per-matrix quantization only)
88+
auto mul_node_input_defs = p_mul_node_right->InputDefs();
89+
if (!optimizer_utils::IsScalar(*mul_node_input_defs[0]) ||
90+
!optimizer_utils::IsScalar(*mul_node_input_defs[1])) {
91+
continue;
92+
}
93+
8794
Node& cast_node = *graph.GetNode(p_cast_node->Index());
8895
Node& matmulinteger_node = *graph.GetNode(p_matmulinteger_node->Index());
8996
Node& mul_node_right = *graph.GetNode(p_mul_node_right->Index());

onnxruntime/core/optimizer/utils.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ bool IsFloatingPointDataType(const ONNX_NAMESPACE::TensorProto& tensor_proto) {
2424
return tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT || tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 || tensor_proto.data_type() == ONNX_NAMESPACE::TensorProto_DataType_DOUBLE;
2525
}
2626

27-
inline bool IsScalar(const NodeArg& input_arg) {
27+
bool IsScalar(const NodeArg& input_arg) {
2828
auto shape = input_arg.Shape();
2929
if (shape == nullptr) {
3030
// shape inferencing wasn't able to populate shape information for this NodeArg

onnxruntime/core/optimizer/utils.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ namespace optimizer_utils {
1515
// Check if TensorProto contains a floating point type.
1616
bool IsFloatingPointDataType(const ONNX_NAMESPACE::TensorProto& tensor_proto);
1717

18+
// Check if input is a scalar
19+
bool IsScalar(const NodeArg& input_arg);
20+
1821
/** Check whether a input is initializer with specified float value.
1922
@param expected_value is the expected value of the initializer.
2023
@param is_constant means whether the initializer is required to be constant.
@@ -60,7 +63,7 @@ bool ValidateShape(const NodeArg& node_arg, const std::initializer_list<int64_t>
6063
*/
6164
bool CompareShape(const ONNX_NAMESPACE::TensorShapeProto& node_arg_shape, const ONNX_NAMESPACE::TensorShapeProto& node_arg_other_shape);
6265

63-
/** Check check whether each dimension is known for shape of node_arg
66+
/** Check whether each dimension is known for shape of node_arg
6467
@returns false when shape is nullptr, or total dimension is not same as expected_dim_size length,
6568
or any dim is unknown (without dim value).
6669
*/

onnxruntime/test/optimizer/graph_transform_test.cc

+3-3
Original file line numberDiff line numberDiff line change
@@ -3069,9 +3069,9 @@ TEST_F(GraphTransformationTests, MatMulIntegerToFloatTest) {
30693069

30703070
std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
30713071
EXPECT_EQ(op_to_count["DynamicQuantizeLinear"], 1);
3072-
EXPECT_EQ(op_to_count["MatMulInteger"], 0);
3073-
EXPECT_EQ(op_to_count["Cast"], 0);
3074-
EXPECT_EQ(op_to_count["Mul"], 0);
3072+
EXPECT_EQ(op_to_count["MatMulInteger"], 1);
3073+
EXPECT_EQ(op_to_count["Cast"], 1);
3074+
EXPECT_EQ(op_to_count["Mul"], 2);
30753075
EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 3);
30763076
EXPECT_EQ(op_to_count["Add"], 1);
30773077
}
Binary file not shown.

onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.py

+5
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def GenerateModel(model_name):
2929
nodes.extend(MakeSubGraph("_1", True))
3030
nodes.extend(MakeSubGraph("_2", True))
3131
nodes.extend(MakeSubGraph("_3", False))
32+
nodes.extend(MakeSubGraph("_4", False))
3233

3334
initializers = []
3435
initializers.extend(MakeInitializer("_1"))
@@ -48,11 +49,15 @@ def GenerateModel(model_name):
4849
helper.make_tensor_value_info('b_quantized_2', TensorProto.UINT8, [2, 3]),
4950
helper.make_tensor_value_info('b_zp_2', TensorProto.UINT8, [1]),
5051
helper.make_tensor_value_info('b_scale_2', TensorProto.FLOAT, [1]),
52+
helper.make_tensor_value_info('b_quantized_4', TensorProto.UINT8, [2, 3]),
53+
helper.make_tensor_value_info('b_zp_4', TensorProto.UINT8, [3]),
54+
helper.make_tensor_value_info('b_scale_4', TensorProto.FLOAT, [3]),
5155
],
5256
[ # outputs
5357
helper.make_tensor_value_info('output_1', TensorProto.FLOAT, [3, 3]),
5458
helper.make_tensor_value_info('output_2', TensorProto.FLOAT, [3, 3]),
5559
helper.make_tensor_value_info('output_3', TensorProto.FLOAT, [3, 3]),
60+
helper.make_tensor_value_info('output_4', TensorProto.FLOAT, [3, 3]),
5661
],
5762
initializers)
5863

0 commit comments

Comments
 (0)