Fix issue #798 - Constant OP conversion doesn't convert scalar values #802

Merged · 1 commit · Oct 23, 2024
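
For context before the diff: StableHLO materializes scalar constants as rank-0 tensors, which TTIR does not support, so the conversion now rebuilds them as 1-D tensors with a single element. A minimal sketch of the intended lowering (written to match the new test expectations further down; the %1 name is illustrative):

    // Scalar StableHLO constant (a rank-0 tensor).
    %0 = stablehlo.constant dense<3> : tensor<i32>

    // After conversion: the value is widened to a 1-D, single-element tensor.
    %1 = "ttir.constant"() <{value = dense<3> : tensor<1xi32>}> : () -> tensor<1xi32>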
122 changes: 67 additions & 55 deletions lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp
@@ -2,6 +2,9 @@
 //
 // SPDX-License-Identifier: Apache-2.0
 
+#include <algorithm>
+#include <vector>
+
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributeInterfaces.h"
 #include "mlir/IR/Value.h"
@@ -23,8 +26,6 @@
 #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
 #include "ttmlir/Dialect/TTIR/IR/TTIR.h"
 #include "ttmlir/Dialect/TTIR/IR/TTIROps.h"
-#include <algorithm>
-#include <vector>
 
 using namespace mlir;
 using namespace mlir::tt;
@@ -69,9 +70,9 @@ class StableHLOToTTIRReduceOpConversionPattern
   matchAndRewrite(mlir::stablehlo::ReduceOp srcOp,
                   mlir::stablehlo::ReduceOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter);
-    if (not err.succeeded()) {
-      return err;
+    LogicalResult legalityResult = checkBasicLegality(srcOp, adaptor, rewriter);
+    if (!legalityResult.succeeded()) {
+      return legalityResult;
     }
 
     const mlir::Operation &innerOp = srcOp.getBody().front().front();
@@ -146,16 +147,16 @@ class StableHLOToTTIRTransposeOpConversionPattern
   matchAndRewrite(mlir::stablehlo::TransposeOp srcOp,
                   mlir::stablehlo::TransposeOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    LogicalResult legalityResult = checkBasicLegality(srcOp, adaptor, rewriter);
+    if (!legalityResult.succeeded()) {
+      return legalityResult;
+    }
+
     auto outputType = mlir::cast<RankedTensorType>(
         getTypeConverter()->convertType(srcOp.getResult().getType()));
     tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
         srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
 
-    LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter);
-    if (not err.succeeded()) {
-      return err;
-    }
-
     rewriter.replaceOpWithNewOp<mlir::tt::ttir::TransposeOp>(
         srcOp, getTypeConverter()->convertType(outputTensor.getType()),
         Value(adaptor.getOperand()), Value(outputTensor),
@@ -233,20 +234,20 @@ class StableHLOToTTIRDotGeneralOpConversionPattern
   matchAndRewrite(mlir::stablehlo::DotGeneralOp srcOp,
                   mlir::stablehlo::DotGeneralOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    auto outputType = mlir::cast<RankedTensorType>(
-        getTypeConverter()->convertType(srcOp.getResult().getType()));
-    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
-        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
-
     // This is a basic version that can only work for cases that can be directly
     // converted to matmul. The op should be extended as other ops such as
     // ttir.permute and ttir.broadcast_in_dim become available.
 
-    LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter);
-    if (not err.succeeded()) {
-      return err;
+    LogicalResult legalityResult = checkBasicLegality(srcOp, adaptor, rewriter);
+    if (!legalityResult.succeeded()) {
+      return legalityResult;
     }
 
+    auto outputType = mlir::cast<RankedTensorType>(
+        getTypeConverter()->convertType(srcOp.getResult().getType()));
+    tensor::EmptyOp outputTensor = rewriter.create<tensor::EmptyOp>(
+        srcOp.getLoc(), outputType.getShape(), outputType.getElementType());
+
     rewriter.replaceOpWithNewOp<mlir::tt::ttir::MatmulOp>(
         srcOp, getTypeConverter()->convertType(outputTensor.getType()),
         adaptor.getLhs(), adaptor.getRhs(), Value(outputTensor),
@@ -282,12 +283,12 @@ class StableHLOToTTIRDotGeneralOpConversionPattern
           srcOp, "Only non-transposed matmul is currently supported in TTIR.");
     }
 
-    if (not dimensions.getLhsBatchingDimensions().empty()) {
+    if (!dimensions.getLhsBatchingDimensions().empty()) {
       return rewriter.notifyMatchFailure(
           srcOp, "Only non-transposed matmul is currently supported in TTIR.");
     }
 
-    if (not dimensions.getRhsBatchingDimensions().empty()) {
+    if (!dimensions.getRhsBatchingDimensions().empty()) {
       return rewriter.notifyMatchFailure(
           srcOp, "Only non-transposed matmul is currently supported in TTIR.");
     }
@@ -301,46 +302,58 @@ class StableHLOToTTIRConstantOpConversionPattern
 
   using OpConversionPattern<mlir::stablehlo::ConstantOp>::OpConversionPattern;
 
-  mlir::ElementsAttr get1DTensor(mlir::stablehlo::ConstantOp srcOp) const {
-    auto outputType = mlir::cast<RankedTensorType>(
-        getTypeConverter()->convertType(srcOp.getResult().getType()));
-
-    assert(outputType.getRank() == 1 &&
-           "Should only be called if constant is scalar.");
-    mlir::ElementsAttr elements;
-    if (auto floatAttr =
-            mlir::cast<mlir::DenseFPElementsAttr>(srcOp.getValue())) {
-      std::vector<mlir::APFloat> floatValues(
-          floatAttr.getValues<mlir::APFloat>().begin(),
-          floatAttr.getValues<mlir::APFloat>().end());
-      elements = mlir::DenseFPElementsAttr::get(outputType, floatValues);
-    } else if (auto intAttr =
-                   mlir::cast<mlir::DenseIntElementsAttr>(srcOp.getValue())) {
-      std::vector<mlir::APInt> intValues(
-          intAttr.getValues<mlir::APInt>().begin(),
-          intAttr.getValues<mlir::APInt>().end());
-      elements = mlir::DenseIntElementsAttr::get(outputType, intValues);
-    } else {
-      assert(false && "Unsupported data type");
-    }
-    return elements;
-  }
-
 public:
   LogicalResult
   matchAndRewrite(mlir::stablehlo::ConstantOp srcOp,
                   mlir::stablehlo::ConstantOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    LogicalResult legalityResult = checkBasicLegality(srcOp, rewriter);
+    if (!legalityResult.succeeded()) {
+      return legalityResult;
+    }
+
     auto outputType = mlir::cast<RankedTensorType>(
         getTypeConverter()->convertType(srcOp.getResult().getType()));
 
-    mlir::ElementsAttr newValue =
-        outputType.getRank() == 1 ? get1DTensor(srcOp) : srcOp.getValue();
+    // Scalar tensors are not supported by TTIR so we have to convert them to
+    // 1-D tensors.
+    mlir::ElementsAttr valueAttr =
+        srcOp.getValue().getShapedType().getShape().empty()
+            ? convertTo1DTensor(srcOp.getValue())
+            : srcOp.getValue();
 
     rewriter.replaceOpWithNewOp<mlir::tt::ttir::ConstantOp>(srcOp, outputType,
-                                                            newValue);
+                                                            valueAttr);
     return success();
   }
 
+private:
+  LogicalResult checkBasicLegality(mlir::stablehlo::ConstantOp &srcOp,
+                                   ConversionPatternRewriter &rewriter) const {
+    if (srcOp.getValue().getShapedType().getShape().empty() &&
+        !srcOp.getValue().getElementType().isIntOrFloat()) {
+      return rewriter.notifyMatchFailure(srcOp, "Unsupported element type.");
+    }
+
+    return success();
+  }
+
+  mlir::ElementsAttr convertTo1DTensor(mlir::ElementsAttr valueAttr) const {
+    mlir::ShapedType valueType = mlir::cast<mlir::ShapedType>(
+        getTypeConverter()->convertType(valueAttr.getShapedType()));
+    if (valueAttr.getElementType().isInteger()) {
+      return mlir::DenseElementsAttr::get<int>(valueType,
+                                               valueAttr.getSplatValue<int>());
+    } else {
+      // In case of float values llvm has a bug where not all float types are
+      // supported for iterating in DenseElementsAttr, so we have to use a
+      // different constructor.
+      std::vector<mlir::APFloat> floatValues(
+          valueAttr.getValues<mlir::APFloat>().begin(),
+          valueAttr.getValues<mlir::APFloat>().end());
+      return mlir::DenseElementsAttr::get(valueType, floatValues);
+    }
+  }
 };
 
 class StableHLOToTTIRConvolutionOpConversionPattern
@@ -558,10 +571,9 @@ class StableHLOToTTIRBroadcastInDimOpConversionPattern
   matchAndRewrite(mlir::stablehlo::BroadcastInDimOp srcOp,
                   mlir::stablehlo::BroadcastInDimOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-
-    LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter);
-    if (not err.succeeded()) {
-      return err;
+    LogicalResult legalityResult = checkBasicLegality(srcOp, adaptor, rewriter);
+    if (!legalityResult.succeeded()) {
+      return legalityResult;
     }
 
     auto outputType = mlir::cast<RankedTensorType>(
@@ -760,9 +772,9 @@ class StableHLOToTTIROpLogicalOpConversionPattern
   LogicalResult
   matchAndRewrite(SrcOp srcOp, Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    LogicalResult err = checkBasicLegality(srcOp, adaptor, rewriter);
-    if (not err.succeeded()) {
-      return err;
+    LogicalResult legalityResult = checkBasicLegality(srcOp, adaptor, rewriter);
+    if (!legalityResult.succeeded()) {
+      return legalityResult;
     }
 
     auto outputType = mlir::cast<RankedTensorType>(
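As an aside, a self-contained C++ sketch of what the new convertTo1DTensor helper does, for readers outside the pattern class; the standalone function name widenScalarAttr is hypothetical, and the calls mirror the MLIR APIs used in the diff above:

    #include <vector>

    #include "mlir/IR/BuiltinAttributes.h"
    #include "mlir/IR/BuiltinTypes.h"

    // Illustration only: rebuild a rank-0 ElementsAttr as a rank-1,
    // single-element attribute of the same element type.
    mlir::ElementsAttr widenScalarAttr(mlir::ElementsAttr scalar) {
      auto oneD = mlir::RankedTensorType::get({1}, scalar.getElementType());
      if (scalar.getElementType().isInteger()) {
        // Integers: the splat-value constructor is enough.
        return mlir::DenseElementsAttr::get<int>(oneD,
                                                 scalar.getSplatValue<int>());
      }
      // Floats: collect APFloat values and use the ArrayRef constructor,
      // matching the iteration workaround noted in the diff above.
      std::vector<mlir::APFloat> values(scalar.getValues<mlir::APFloat>().begin(),
                                        scalar.getValues<mlir::APFloat>().end());
      return mlir::DenseElementsAttr::get(oneD, values);
    }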
20 changes: 18 additions & 2 deletions test/ttmlir/Conversion/StableHLOToTTIR/constant_op.mlir
@@ -3,13 +3,29 @@
 module @jit_constant attributes {} {
   func.func public @test_splat() -> tensor<64xf32> {
     %0 = stablehlo.constant dense<0.3> : tensor<64xf32>
-    // CHECK: %[[C:.*]] = "ttir.constant"[[C:.*]]
+    // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<64xf32>}> : () -> tensor<64xf32>
     return %0 : tensor<64xf32>
   }
 
   func.func public @test_multiple() -> tensor<2x2xf32> {
     %0 = stablehlo.constant dense<[[0.0, 1.0], [2.0, 3.0]]> : tensor<2x2xf32>
-    // CHECK: %[[C:.*]] = "ttir.constant"[[C:.*]]
+    // The ugly regex after `dense` is necessary because double square opening
+    // brackets indicate substitution block in FileCheck syntax.
+    // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<{{([[])}}[0.000000e+00, 1.000000e+00], [2.000000e+00, 3.000000e+00]]> : tensor<2x2xf32>}> : () -> tensor<2x2xf32>
     return %0 : tensor<2x2xf32>
   }
+
+  func.func public @test_scalar_int() -> tensor<i32> {
+    %0 = stablehlo.constant dense<3> : tensor<i32>
+    // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3> : tensor<1xi32>}> : () -> tensor<1xi32>
+    return %0 : tensor<i32>
+    // CHECK: return %{{[0-9]+}} : tensor<1xi32>
+  }
+
+  func.func public @test_scalar_float() -> tensor<f32> {
+    %0 = stablehlo.constant dense<0.3> : tensor<f32>
+    // CHECK: %{{[0-9]+}} = "ttir.constant"() <{value = dense<3.000000e-01> : tensor<1xf32>}> : () -> tensor<1xf32>
+    return %0 : tensor<f32>
+    // CHECK: return %{{[0-9]+}} : tensor<1xf32>
+  }
 }
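
A note on the FileCheck escaping used in @test_multiple above: `[[` opens a substitution block in FileCheck (as in `%[[C:.*]]`), so a literal `[[` in the expected IR cannot be written directly and is instead produced through a regex match block. A brief illustration (the dense values are placeholders):

    // Parsed as a FileCheck substitution block, not a literal match:
    // CHECK: dense<[[0.000000e+00, ...
    // Wrapping the first bracket in a regex block matches it literally:
    // CHECK: dense<{{([[])}}[0.000000e+00, ...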