diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 3db33aee1f1c..940b88ab02ed 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -568,36 +568,92 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
               rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0));
         }
 
+        bool isChannelQuantizationForLhs = false;
         if (auto zpTy = dyn_cast<Torch::ValueTensorType>(lhsZp.getType())) {
-          for (auto dim : zpTy.getSizes())
-            if (dim != 1)
+          auto lhsZpSize = zpTy.getSizes();
+          if (lhsZpSize.size() == 0 ||
+              llvm::all_of(lhsZpSize, [](int64_t d) { return d == 1; })) {
+            lhsZp = rewriter.create<Torch::AtenItemOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(), lhsZp);
+          } else if (lhsZpSize.size() == 1) {
+            auto lhsSize = lhsTy.getSizes();
+            if (lhsSize.size() != 2 || lhsSize[0] != lhsZpSize[0])
               return failure();
-          lhsZp = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(), lhsZp);
+            isChannelQuantizationForLhs = true;
+          } else {
+            return failure();
+          }
         }
 
+        bool isChannelQuantizationForRhs = false;
         if (auto zpTy = dyn_cast<Torch::ValueTensorType>(rhsZp.getType())) {
-          for (auto dim : zpTy.getSizes())
-            if (dim != 1)
+          auto rhsZpSize = zpTy.getSizes();
+          if (rhsZpSize.size() == 0 ||
+              llvm::all_of(rhsZpSize, [](int64_t d) { return d == 1; })) {
+            rhsZp = rewriter.create<Torch::AtenItemOp>(
+                binder.getLoc(), rewriter.getType<Torch::IntType>(), rhsZp);
+          } else if (rhsZpSize.size() == 1) {
+            auto rhsSize = rhsTy.getSizes();
+            if (rhsSize.size() != 2 || rhsSize[1] != rhsZpSize[0])
               return failure();
-          rhsZp = rewriter.create<Torch::AtenItemOp>(
-              binder.getLoc(), rewriter.getType<Torch::IntType>(), rhsZp);
+            isChannelQuantizationForRhs = true;
+          } else {
+            return failure();
+          }
         }
 
-        Value scale = rewriter.create<Torch::ConstantFloatOp>(
-            binder.getLoc(), rewriter.getType<Torch::FloatType>(),
-            rewriter.getF64FloatAttr(1.0));
-
         auto lhsQTy = getQTorchTypeFromTorchIntType(lhsTy);
         auto rhsQTy = getQTorchTypeFromTorchIntType(rhsTy);
         if (!lhsQTy || !rhsQTy)
           return rewriter.notifyMatchFailure(binder.op, "failed to get qtype");
 
-        lhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
-            binder.getLoc(), lhsQTy, lhs, scale, lhsZp);
-        rhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
-            binder.getLoc(), rhsQTy, rhs, scale, rhsZp);
+        Value f32Ty = rewriter.create<Torch::ConstantIntOp>(
+            binder.getLoc(), rewriter.getI64IntegerAttr(
+                                 (int64_t)torch_upstream::ScalarType::Float));
+        Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
+
+        if (isChannelQuantizationForLhs) {
+          Value axis = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              rewriter.getI64IntegerAttr(0));
+          Torch::ValueTensorType lhsZpTy =
+              dyn_cast<Torch::ValueTensorType>(lhsZp.getType());
+          Type scaleTy = lhsZpTy.getWithSizesAndDtype(lhsZpTy.getSizes(),
+                                                      rewriter.getF32Type());
+          Value scale = rewriter.create<Torch::AtenOnesLikeOp>(
+              binder.getLoc(), scaleTy, /*self=*/lhsZp, f32Ty, /*layout=*/none,
+              /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
+          lhs = rewriter.create<Torch::Aten_MakePerChannelQuantizedTensorOp>(
+              binder.getLoc(), lhsQTy, lhs, scale, lhsZp, axis);
+        } else {
+          Value scale = rewriter.create<Torch::ConstantFloatOp>(
+              binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+              rewriter.getF64FloatAttr(1.0));
+          lhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+              binder.getLoc(), lhsQTy, lhs, scale, lhsZp);
+        }
+
+        if (isChannelQuantizationForRhs) {
+          Value axis = rewriter.create<Torch::ConstantIntOp>(
+              binder.getLoc(), rewriter.getType<Torch::IntType>(),
+              rewriter.getI64IntegerAttr(1));
+          Torch::ValueTensorType rhsZpTy =
+              dyn_cast<Torch::ValueTensorType>(rhsZp.getType());
+          Type scaleTy = rhsZpTy.getWithSizesAndDtype(rhsZpTy.getSizes(),
+                                                      rewriter.getF32Type());
+          Value scale = rewriter.create<Torch::AtenOnesLikeOp>(
+              binder.getLoc(), scaleTy, /*self=*/rhsZp, f32Ty, /*layout=*/none,
+              /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none);
+          rhs = rewriter.create<Torch::Aten_MakePerChannelQuantizedTensorOp>(
+              binder.getLoc(), rhsQTy, rhs, scale, rhsZp, axis);
+        } else {
+          Value scale = rewriter.create<Torch::ConstantFloatOp>(
+              binder.getLoc(), rewriter.getType<Torch::FloatType>(),
+              rewriter.getF64FloatAttr(1.0));
+          rhs = rewriter.create<Torch::Aten_MakePerTensorQuantizedTensorOp>(
+              binder.getLoc(), rhsQTy, rhs, scale, rhsZp);
+        }
 
         rewriter.replaceOpWithNewOp<Torch::AtenMatmulOp>(binder.op, resultType,
                                                          lhs, rhs);
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index b2c718bceace..8958db34db14 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -569,9 +569,10 @@ func.func @test_matmulinteger(%arg0: !torch.vtensor<[4,3],ui8>, %arg1: !torch.vt
   %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[4,3],ui8>, !torch.vtensor<[3,2],ui8>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[4,2],si32>
   // CHECK: %[[LITEM:.+]] = torch.aten.item %arg2
   // CHECK: %[[RITEM:.+]] = torch.aten.item %arg3
-  // CHECK: %[[SCALE:.+]] = torch.constant.float 1.000000e+00
-  // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[LITEM]] : !torch.vtensor<[4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.quint8>
-  // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
+  // CHECK: %[[L_SCALE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[L_SCALE]], %[[LITEM]] : !torch.vtensor<[4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.quint8>
+  // CHECK: %[[R_SCALE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[R_SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
   // CHECK: %[[MM:.+]] = torch.aten.matmul %[[LMAKE]], %[[RMAKE]]
   // CHECK: return %[[MM]]
   return %0 : !torch.vtensor<[4,2],si32>
@@ -584,13 +585,59 @@ func.func @test_matmulinteger_batched(%arg0: !torch.vtensor<[7,4,3],ui8>, %arg1:
   %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[7,4,3],ui8>, !torch.vtensor<[3,2],ui8>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[7,4,2],si32>
   // CHECK: %[[LITEM:.+]] = torch.aten.item %arg2
   // CHECK: %[[RITEM:.+]] = torch.aten.item %arg3
-  // CHECK: %[[SCALE:.+]] = torch.constant.float 1.000000e+00
-  // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[LITEM]] : !torch.vtensor<[7,4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[7,4,3],!torch.quint8>
-  // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
+  // CHECK: %[[L_SCALE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[L_SCALE]], %[[LITEM]] : !torch.vtensor<[7,4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[7,4,3],!torch.quint8>
+  // CHECK: %[[R_SCALE:.+]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[R_SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8>
   // CHECK: %[[MM:.+]] = torch.aten.matmul %[[LMAKE]], %[[RMAKE]]
   // CHECK: return %[[MM]]
   return %0 : !torch.vtensor<[7,4,2],si32>
 }
+
+// -----
+
+// CHECK-LABEL:   func.func @test_matmulinteger_non_scalar_lhsZp(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[16,2],ui8>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[2,768],si8>,
+// CHECK-SAME:      %[[VAL_2:.*]]: !torch.vtensor<[16],ui8>,
+// CHECK-SAME:      %[[VAL_3:.*]]: !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
+func.func @test_matmulinteger_non_scalar_lhsZp(%arg0: !torch.vtensor<[16, 2],ui8>, %arg1: !torch.vtensor<[2,768],si8>, %arg2: !torch.vtensor<[16],ui8>, %arg3: !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
+  // CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_3]] : !torch.vtensor<[],si8> -> !torch.int
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 6
+  // CHECK: %[[VAL_6:.*]] = torch.constant.none
+  // CHECK: %[[VAL_7:.*]] = torch.constant.int 0
+  // CHECK: %[[VAL_8:.*]] = torch.aten.ones_like %[[VAL_2]], %[[VAL_5]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[16],ui8>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[16],f32>
+  // CHECK: %[[VAL_9:.*]] = torch.aten._make_per_channel_quantized_tensor %[[VAL_0]], %[[VAL_8]], %[[VAL_2]], %[[VAL_7]] : !torch.vtensor<[16,2],ui8>, !torch.vtensor<[16],f32>, !torch.vtensor<[16],ui8>, !torch.int -> !torch.vtensor<[16,2],!torch.quint8>
+  // CHECK: %[[VAL_10:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[VAL_11:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[VAL_1]], %[[VAL_10]], %[[VAL_4]] : !torch.vtensor<[2,768],si8>, !torch.float, !torch.int -> !torch.vtensor<[2,768],!torch.qint8>
+  // CHECK: %[[VAL_12:.*]] = torch.aten.matmul %[[VAL_9]], %[[VAL_11]] : !torch.vtensor<[16,2],!torch.quint8>, !torch.vtensor<[2,768],!torch.qint8> -> !torch.vtensor<[16,768],si32>
+  // CHECK: return %[[VAL_12]] : !torch.vtensor<[16,768],si32>
+  %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[16,2],ui8>, !torch.vtensor<[2,768],si8>, !torch.vtensor<[16],ui8>, !torch.vtensor<[],si8>) -> !torch.vtensor<[16,768],si32>
+  return %0 : !torch.vtensor<[16,768],si32>
+}
+
+// -----
+
+// CHECK-LABEL:   func.func @test_matmulinteger_non_scalar_rhsZp(
+// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[?,?],ui8>,
+// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[2,768],si8>,
+// CHECK-SAME:      %[[VAL_2:.*]]: !torch.vtensor<[],ui8>,
+// CHECK-SAME:      %[[VAL_3:.*]]: !torch.vtensor<[768],si8>) -> !torch.vtensor<[?,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
+func.func @test_matmulinteger_non_scalar_rhsZp(%arg0: !torch.vtensor<[?,?],ui8>, %arg1: !torch.vtensor<[2,768],si8>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[768],si8>) -> !torch.vtensor<[?,768],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "pytorch", torch.onnx_meta.producer_version = "0.1.0"} {
+  // CHECK: %[[VAL_4:.*]] = torch.aten.item %[[VAL_2]] : !torch.vtensor<[],ui8> -> !torch.int
+  // CHECK: %[[VAL_5:.*]] = torch.constant.int 6
+  // CHECK: %[[VAL_6:.*]] = torch.constant.none
+  // CHECK: %[[VAL_7:.*]] = torch.constant.float 1.000000e+00
+  // CHECK: %[[VAL_8:.*]] = torch.aten._make_per_tensor_quantized_tensor %[[VAL_0]], %[[VAL_7]], %[[VAL_4]] : !torch.vtensor<[?,?],ui8>, !torch.float, !torch.int -> !torch.vtensor<[?,?],!torch.quint8>
+  // CHECK: %[[VAL_9:.*]] = torch.constant.int 1
+  // CHECK: %[[VAL_10:.*]] = torch.aten.ones_like %[[VAL_3]], %[[VAL_5]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]], %[[VAL_6]] : !torch.vtensor<[768],si8>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[768],f32>
+  // CHECK: %[[VAL_11:.*]] = torch.aten._make_per_channel_quantized_tensor %[[VAL_1]], %[[VAL_10]], %[[VAL_3]], %[[VAL_9]] : !torch.vtensor<[2,768],si8>, !torch.vtensor<[768],f32>, !torch.vtensor<[768],si8>, !torch.int -> !torch.vtensor<[2,768],!torch.qint8>
+  // CHECK: %[[VAL_12:.*]] = torch.aten.matmul %[[VAL_8]], %[[VAL_11]] : !torch.vtensor<[?,?],!torch.quint8>, !torch.vtensor<[2,768],!torch.qint8> -> !torch.vtensor<[?,768],si32>
+  // CHECK: return %[[VAL_12]] : !torch.vtensor<[?,768],si32>
+  %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[?,?],ui8>, !torch.vtensor<[2,768],si8>, !torch.vtensor<[],ui8>, !torch.vtensor<[768],si8>) -> !torch.vtensor<[?,768],si32>
+  return %0 : !torch.vtensor<[?,768],si32>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @test_mul