diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
index 687a7763ed..ab8d41fb0d 100644
--- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
+++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -2,6 +2,9 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
+#include <algorithm>
+#include <vector>
+
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/Value.h"
@@ -23,8 +26,6 @@
 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
 #include "ttmlir/Dialect/TTIR/IR/TTIR.h"
 #include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
-#include <algorithm>
-#include <vector>
 
 using namespace mlir;
 using namespace mlir::tt;
@@ -69,9 +70,9 @@ class StableHLOToTTIRReduceOpConversionPattern
   matchAndRewrite(mlir::stablehlo::ReduceOp srcOp,
                   mlir::stablehlo::ReduceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter);
-    if (not err.succeeded()) {
-      return err;
+    LogicalResult legalityResult = checkBasicLegality(srcOp, adaptor, rewriter);
+    if (!legalityResult.succeeded()) {
+      return legalityResult;
     }
 
     const mlir::Operation &innerOp = srcOp.getBody().front().front();
@@ -146,16 +147,16 @@ class StableHLOToTTIRTransposeOpConversionPattern
   matchAndRewrite(mlir::stablehlo::TransposeOp srcOp,
                   mlir::stablehlo::TransposeOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    LogicalResult legalityResult = checkBasicLegality(srcOp, adaptor, rewriter);
+    if (!legalityResult.succeeded()) {
+      return legalityResult;
+    }
+
     auto outputType = mlir::cast<mlir::RankedTensorType>(
         getTypeConverter()->convertType(srcOp.getResult().getType()));
     tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
         srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
 
-    LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter);
-    if (not err.succeeded()) {
-      return err;
-    }
-
     rewriter.replaceOpWithNewOp<mlir::tt::ttir::TransposeOp>(
         srcOp, getTypeConverter()->convertType(outputTensor.getType()),
         Value(adaptor.getOperand()), Value(outputTensor),
@@ -233,20 +234,20 @@ class StableHLOToTTIRDotGeneralOpConversionPattern
   matchAndRewrite(mlir::stablehlo::DotGeneralOp srcOp,
                   mlir::stablehlo::DotGeneralOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto outputType = mlir::cast<mlir::RankedTensorType>(
-        getTypeConverter()->convertType(srcOp.getResult().getType()));
-    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
-        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
-
     // This is a basic version that can only work for cases that can be directly
     // converted to matmul. The op should be extended as other ops such as
     // ttir.permute and ttir.broadcast_in_dim become available.
 
-    LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter);
-    if (not err.succeeded()) {
-      return err;
+    LogicalResult legalityResult = checkBasicLegality(srcOp, adaptor, rewriter);
+    if (!legalityResult.succeeded()) {
+      return legalityResult;
     }
 
+    auto outputType = mlir::cast<mlir::RankedTensorType>(
+        getTypeConverter()->convertType(srcOp.getResult().getType()));
+    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
+        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
+
     rewriter.replaceOpWithNewOp<mlir::tt::ttir::MatmulOp>(
         srcOp, getTypeConverter()->convertType(outputTensor.getType()),
         adaptor.getLhs(), adaptor.getRhs(), Value(outputTensor),
@@ -282,12 +283,12 @@ class StableHLOToTTIRDotGeneralOpConversionPattern
           srcOp, "Only non-transposed matmul is currently supported in TTIR.");
     }
 
-    if (not dimensions.getLhsBatchingDimensions().empty()) {
+    if (!dimensions.getLhsBatchingDimensions().empty()) {
       return rewriter.notifyMatchFailure(
           srcOp, "Only non-transposed matmul is currently supported in TTIR.");
     }
 
-    if (not dimensions.getRhsBatchingDimensions().empty()) {
+    if (!dimensions.getRhsBatchingDimensions().empty()) {
       return rewriter.notifyMatchFailure(
           srcOp, "Only non-transposed matmul is currently supported in TTIR.");
     }
@@ -301,46 +302,58 @@ class StableHLOToTTIRConstantOpConversionPattern
 
   using OpConversionPattern<mlir::stablehlo::ConstantOp>::OpConversionPattern;
 
-  mlir::ElementsAttr get1DTensor(mlir::stablehlo::ConstantOp srcOp) const {
-    auto outputType = mlir::cast<mlir::RankedTensorType>(
-        getTypeConverter()->convertType(srcOp.getResult().getType()));
-
-    assert(outputType.getRank() == 1 &&
-           "Should only be called if constant is scalar.");
-    mlir::ElementsAttr elements;
-    if (auto floatAttr =
-            mlir::cast<mlir::DenseFPElementsAttr>(srcOp.getValue())) {
-      std::vector<float> floatValues(
-          floatAttr.getValues<float>().begin(),
-          floatAttr.getValues<float>().end());
-      elements = mlir::DenseFPElementsAttr::get(outputType, floatValues);
-    } else if (auto intAttr =
-                   mlir::cast<mlir::DenseIntElementsAttr>(srcOp.getValue())) {
-      std::vector<int> intValues(
-          intAttr.getValues<int>().begin(),
-          intAttr.getValues<int>().end());
-      elements = mlir::DenseIntElementsAttr::get(outputType, intValues);
-    } else {
-      assert(false && "Unsupported data type");
-    }
-    return elements;
-  }
-
 public:
   LogicalResult
   matchAndRewrite(mlir::stablehlo::ConstantOp srcOp,
                   mlir::stablehlo::ConstantOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    LogicalResult legalityResult = checkBasicLegality(srcOp, rewriter);
+    if (!legalityResult.succeeded()) {
+      return legalityResult;
+    }
+
     auto outputType = mlir::cast<mlir::RankedTensorType>(
         getTypeConverter()->convertType(srcOp.getResult().getType()));
 
-    mlir::ElementsAttr newValue =
-        outputType.getRank() == 1 ? get1DTensor(srcOp) : srcOp.getValue();
+    // Scalar tensors are not supported by TTIR so we have to convert them to
+    // 1-D tensors.
+    mlir::ElementsAttr valueAttr =
+        srcOp.getValue().getShapedType().getShape().empty()
+            ? convertTo1DTensor(srcOp.getValue())
+            : srcOp.getValue();
 
     rewriter.replaceOpWithNewOp<mlir::tt::ttir::ConstantOp>(srcOp, outputType,
-                                                            newValue);
+                                                            valueAttr);
+    return success();
+  }
+
+private:
+  LogicalResult checkBasicLegality(mlir::stablehlo::ConstantOp &srcOp,
+                                   ConversionPatternRewriter &rewriter) const {
+    if (srcOp.getValue().getShapedType().getShape().empty() &&
+        !srcOp.getValue().getElementType().isIntOrFloat()) {
+      return rewriter.notifyMatchFailure(srcOp, "Unsupported element type.");
+    }
+
     return success();
   }
+
+  mlir::ElementsAttr convertTo1DTensor(mlir::ElementsAttr valueAttr) const {
+    mlir::ShapedType valueType = mlir::cast<mlir::ShapedType>(
+        getTypeConverter()->convertType(valueAttr.getShapedType()));
+    if (valueAttr.getElementType().isInteger()) {
+      return mlir::DenseElementsAttr::get(valueType,
+                                          valueAttr.getSplatValue<int>());
+    } else {
+      // In case of float values llvm has a bug where not all float types are
+      // supported for iterating in DenseElementsAttr, so we have to use a
+      // different constructor.
+      std::vector<float> floatValues(
+          valueAttr.getValues<float>().begin(),
+          valueAttr.getValues<float>().end());
+      return mlir::DenseElementsAttr::get(valueType, floatValues);
+    }
+  }
 };
 
 class StableHLOToTTIRConvolutionOpConversionPattern
@@ -558,10 +571,9 @@ class StableHLOToTTIRBroadcastInDimOpConversionPattern
   matchAndRewrite(mlir::stablehlo::BroadcastInDimOp srcOp,
                   mlir::stablehlo::BroadcastInDimOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
-    LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter);
-    if (not err.succeeded()) {
-      return err;
+    LogicalResult legalityResult = checkBasicLegality(srcOp, adaptor, rewriter);
+    if (!legalityResult.succeeded()) {
+      return legalityResult;
     }
 
     auto outputType = mlir::cast<mlir::RankedTensorType>(
@@ -760,9 +772,9 @@ class StableHLOToTTIROpLogicalOpConversionPattern
   LogicalResult
   matchAndRewrite(SrcOp srcOp, Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter);
-    if (not err.succeeded()) {
-      return err;
+    LogicalResult legalityResult = checkBasicLegality(srcOp, adaptor, rewriter);
+    if (!legalityResult.succeeded()) {
+      return legalityResult;
     }
 
     auto outputType = mlir::cast<mlir::RankedTensorType>(
diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir
index 72fb2eef49..6e5aea1e20 100644
--- a/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir
+++ b/test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir
@@ -3,13 +3,29 @@
 module @jit_constant attributes {} {
   func.func public @test_splat() -> tensor<64xf32> {
     %0 = stablehlo.constant dense<0.3> : tensor<64xf32>
-    // CHECK: %[[C:.*]] = "ttir.constant"[[C:.*]]
+    // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<64xf32>}> : () -> tensor<64xf32>
     return %0 : tensor<64xf32>
   }
 
   func.func public @test_multiple() -> tensor<2x2xf32> {
     %0 = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
-    // CHECK: %[[C:.*]] = "ttir.constant"[[C:.*]]
+    // The ugly regex after `dense` is necessary because double square opening
+    // brackets indicate substitution block in FileCheck syntax.
+    // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
     return %0 : tensor<2x2xf32>
   }
+
+  func.func public @test_scalar_int() -> tensor<i32> {
+    %0 = stablehlo.constant dense<3> : tensor<i32>
+    // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xi32>}> : () -> tensor<1xi32>
+    return %0 : tensor<i32>
+    // CHECK: return %{{[0-9]+}} : tensor<1xi32>
+  }
+
+  func.func public @test_scalar_float() -> tensor<f32> {
+    %0 = stablehlo.constant dense<0.3> : tensor<f32>
+    // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
+    return %0 : tensor<f32>
+    // CHECK: return %{{[0-9]+}} : tensor<1xf32>
+  }
 }