Skip to content

[AutoBump] Merge with fixes of f4943464 (Jan 18) (2) Needs onnx-mlir bump #534

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: feature/fused-ops
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions mlir/include/mlir/IR/Builders.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,6 @@ class Builder {
Attribute metadata = Attribute());

// Types.
FloatType getFloat4E2M1FNType();
FloatType getFloat6E2M3FNType();
FloatType getFloat6E3M2FNType();
FloatType getFloat8E5M2Type();
FloatType getFloat8E4M3Type();
FloatType getFloat8E4M3FNType();
FloatType getFloat8E5M2FNUZType();
FloatType getFloat8E4M3FNUZType();
FloatType getFloat8E4M3B11FNUZType();
FloatType getFloat8E3M4Type();
FloatType getFloat8E8M0FNUType();
FloatType getBF16Type();
FloatType getF16Type();
FloatType getTF32Type();
Expand Down
20 changes: 13 additions & 7 deletions mlir/include/mlir/IR/BuiltinTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ class Builtin_FloatType<string name, string mnemonic,
DeclareTypeInterfaceMethods<
FloatTypeInterface,
["getFloatSemantics"] # declaredInterfaceMethods>]> {
}

// Float types that are cached in MLIRContext.
class Builtin_CachedFloatType<string name, string mnemonic,
list<string> declaredInterfaceMethods = []>
: Builtin_FloatType<name, mnemonic, declaredInterfaceMethods> {
let extraClassDeclaration = [{
static }] # name # [{Type get(MLIRContext *context);
}];
Expand Down Expand Up @@ -326,52 +332,52 @@ def Builtin_Float8E8M0FNU : Builtin_FloatType<"Float8E8M0FNU", "f8E8M0FNU"> {
//===----------------------------------------------------------------------===//
// BFloat16Type

def Builtin_BFloat16 : Builtin_FloatType<"BFloat16", "bf16",
def Builtin_BFloat16 : Builtin_CachedFloatType<"BFloat16", "bf16",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "bfloat16 floating-point type";
}

//===----------------------------------------------------------------------===//
// Float16Type

def Builtin_Float16 : Builtin_FloatType<"Float16", "f16",
def Builtin_Float16 : Builtin_CachedFloatType<"Float16", "f16",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "16-bit floating-point type";
}

//===----------------------------------------------------------------------===//
// FloatTF32Type

def Builtin_FloatTF32 : Builtin_FloatType<"FloatTF32", "tf32"> {
def Builtin_FloatTF32 : Builtin_CachedFloatType<"FloatTF32", "tf32"> {
let summary = "TF32 floating-point type";
}

//===----------------------------------------------------------------------===//
// Float32Type

def Builtin_Float32 : Builtin_FloatType<"Float32", "f32",
def Builtin_Float32 : Builtin_CachedFloatType<"Float32", "f32",
/*declaredInterfaceMethods=*/["scaleElementBitwidth"]> {
let summary = "32-bit floating-point type";
}

//===----------------------------------------------------------------------===//
// Float64Type

def Builtin_Float64 : Builtin_FloatType<"Float64", "f64"> {
def Builtin_Float64 : Builtin_CachedFloatType<"Float64", "f64"> {
let summary = "64-bit floating-point type";
}

//===----------------------------------------------------------------------===//
// Float80Type

def Builtin_Float80 : Builtin_FloatType<"Float80", "f80"> {
def Builtin_Float80 : Builtin_CachedFloatType<"Float80", "f80"> {
let summary = "80-bit floating-point type";
}

//===----------------------------------------------------------------------===//
// Float128Type

def Builtin_Float128 : Builtin_FloatType<"Float128", "f128"> {
def Builtin_Float128 : Builtin_CachedFloatType<"Float128", "f128"> {
let summary = "128-bit floating-point type";
}

Expand Down
26 changes: 13 additions & 13 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -330,31 +330,31 @@ def F80 : F<80>;
def F128 : F<128>;

def BF16 : Type<CPred<"$_self.isBF16()">, "bfloat16 type">,
BuildableType<"$_builder.getBF16Type()">;
BuildableType<"$_builder.getType<BFloat16Type>()">;
def TF32 : Type<CPred<"$_self.isTF32()">, "tf32 type">,
BuildableType<"$_builder.getTF32Type()">;
BuildableType<"$_builder.getType<FloatTF32Type>()">;
def F8E4M3FN : Type<CPred<"$_self.isFloat8E4M3FN()">, "f8E4M3FN type">,
BuildableType<"$_builder.getFloat8E4M3FNType()">;
BuildableType<"$_builder.getType<Float8E4M3FNType>()">;
def F8E5M2 : Type<CPred<"$_self.isFloat8E5M2()">, "f8E5M2 type">,
BuildableType<"$_builder.getFloat8E5M2Type()">;
BuildableType<"$_builder.getType<Float8E5M2Type>()">;
def F8E4M3 : Type<CPred<"$_self.isFloat8E4M3()">, "f8E4M3 type">,
BuildableType<"$_builder.getFloat8E4M3Type()">;
BuildableType<"$_builder.getType<Float8E4M3Type>()">;
def F8E4M3FNUZ : Type<CPred<"$_self.isFloat8E4M3FNUZ()">, "f8E4M3FNUZ type">,
BuildableType<"$_builder.getFloat8E4M3FNUZType()">;
BuildableType<"$_builder.getType<Float8E4M3FNUZType>()">;
def F8E4M3B11FNUZ : Type<CPred<"$_self.isFloat8E4M3B11FNUZ()">, "f8E4M3B11FNUZ type">,
BuildableType<"$_builder.getFloat8E4M3B11FNUZType()">;
BuildableType<"$_builder.getType<Float8E4M3B11FNUZType>()">;
def F8E5M2FNUZ : Type<CPred<"$_self.isFloat8E5M2FNUZ()">, "f8E5M2FNUZ type">,
BuildableType<"$_builder.getFloat8E5M2FNUZType()">;
BuildableType<"$_builder.getType<Float8E5M2FNUZType>()">;
def F8E3M4 : Type<CPred<"$_self.isFloat8E3M4()">, "f8E3M4 type">,
BuildableType<"$_builder.getFloat8E3M4Type()">;
BuildableType<"$_builder.getType<Float8E3M4Type>()">;
def F4E2M1FN : Type<CPred<"$_self.isFloat4E2M1FN()">, "f4E2M1FN type">,
BuildableType<"$_builder.getFloat4E2M1FNType()">;
BuildableType<"$_builder.getType<Float4E2M1FNType>()">;
def F6E2M3FN : Type<CPred<"$_self.isFloat6E2M3FN()">, "f6E2M3FN type">,
BuildableType<"$_builder.getFloat6E2M3FNType()">;
BuildableType<"$_builder.getType<Float6E2M3FNType>()">;
def F6E3M2FN : Type<CPred<"$_self.isFloat6E3M2FN()">, "f6E3M2FN type">,
BuildableType<"$_builder.getFloat6E3M2FNType()">;
BuildableType<"$_builder.getType<Float6E3M2FNType>()">;
def F8E8M0FNU : Type<CPred<"$_self.isFloat8E8M0FNU()">, "f8E8M0FNU type">,
BuildableType<"$_builder.getFloat8E8M0FNUType()">;
BuildableType<"$_builder.getType<Float8E8M0FNUType>()">;

def AnyComplex : Type<CPred<"::llvm::isa<::mlir::ComplexType>($_self)">,
"complex-type", "::mlir::ComplexType">;
Expand Down
36 changes: 18 additions & 18 deletions mlir/lib/AsmParser/TypeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,58 +309,58 @@ Type Parser::parseNonFunctionType() {
// float-type
case Token::kw_f4E2M1FN:
consumeToken(Token::kw_f4E2M1FN);
return builder.getFloat4E2M1FNType();
return builder.getType<Float4E2M1FNType>();
case Token::kw_f6E2M3FN:
consumeToken(Token::kw_f6E2M3FN);
return builder.getFloat6E2M3FNType();
return builder.getType<Float6E2M3FNType>();
case Token::kw_f6E3M2FN:
consumeToken(Token::kw_f6E3M2FN);
return builder.getFloat6E3M2FNType();
return builder.getType<Float6E3M2FNType>();
case Token::kw_f8E5M2:
consumeToken(Token::kw_f8E5M2);
return builder.getFloat8E5M2Type();
return builder.getType<Float8E5M2Type>();
case Token::kw_f8E4M3:
consumeToken(Token::kw_f8E4M3);
return builder.getFloat8E4M3Type();
return builder.getType<Float8E4M3Type>();
case Token::kw_f8E4M3FN:
consumeToken(Token::kw_f8E4M3FN);
return builder.getFloat8E4M3FNType();
return builder.getType<Float8E4M3FNType>();
case Token::kw_f8E5M2FNUZ:
consumeToken(Token::kw_f8E5M2FNUZ);
return builder.getFloat8E5M2FNUZType();
return builder.getType<Float8E5M2FNUZType>();
case Token::kw_f8E4M3FNUZ:
consumeToken(Token::kw_f8E4M3FNUZ);
return builder.getFloat8E4M3FNUZType();
return builder.getType<Float8E4M3FNUZType>();
case Token::kw_f8E4M3B11FNUZ:
consumeToken(Token::kw_f8E4M3B11FNUZ);
return builder.getFloat8E4M3B11FNUZType();
return builder.getType<Float8E4M3B11FNUZType>();
case Token::kw_f8E3M4:
consumeToken(Token::kw_f8E3M4);
return builder.getFloat8E3M4Type();
return builder.getType<Float8E3M4Type>();
case Token::kw_f8E8M0FNU:
consumeToken(Token::kw_f8E8M0FNU);
return builder.getFloat8E8M0FNUType();
return builder.getType<Float8E8M0FNUType>();
case Token::kw_bf16:
consumeToken(Token::kw_bf16);
return builder.getBF16Type();
return builder.getType<BFloat16Type>();
case Token::kw_f16:
consumeToken(Token::kw_f16);
return builder.getF16Type();
return builder.getType<Float16Type>();
case Token::kw_tf32:
consumeToken(Token::kw_tf32);
return builder.getTF32Type();
return builder.getType<FloatTF32Type>();
case Token::kw_f32:
consumeToken(Token::kw_f32);
return builder.getF32Type();
return builder.getType<Float32Type>();
case Token::kw_f64:
consumeToken(Token::kw_f64);
return builder.getF64Type();
return builder.getType<Float64Type>();
case Token::kw_f80:
consumeToken(Token::kw_f80);
return builder.getF80Type();
return builder.getType<Float80Type>();
case Token::kw_f128:
consumeToken(Token::kw_f128);
return builder.getF128Type();
return builder.getType<Float128Type>();

// index-type
case Token::kw_index:
Expand Down
32 changes: 16 additions & 16 deletions mlir/lib/Dialect/Arith/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,22 +361,22 @@ Value createProduct(OpBuilder &builder, Location loc, ArrayRef<Value> values,
std::optional<FloatType> parseFloatType(MLIRContext *ctx, StringRef name) {
Builder b(ctx);
return llvm::StringSwitch<std::optional<FloatType>>(name)
.Case("f4E2M1FN", b.getFloat4E2M1FNType())
.Case("f6E2M3FN", b.getFloat6E2M3FNType())
.Case("f6E3M2FN", b.getFloat6E3M2FNType())
.Case("f8E5M2", b.getFloat8E5M2Type())
.Case("f8E4M3", b.getFloat8E4M3Type())
.Case("f8E4M3FN", b.getFloat8E4M3FNType())
.Case("f8E5M2FNUZ", b.getFloat8E5M2FNUZType())
.Case("f8E4M3FNUZ", b.getFloat8E4M3FNUZType())
.Case("f8E3M4", b.getFloat8E3M4Type())
.Case("f8E8M0FNU", b.getFloat8E8M0FNUType())
.Case("bf16", b.getBF16Type())
.Case("f16", b.getF16Type())
.Case("f32", b.getF32Type())
.Case("f64", b.getF64Type())
.Case("f80", b.getF80Type())
.Case("f128", b.getF128Type())
.Case("f4E2M1FN", b.getType<Float4E2M1FNType>())
.Case("f6E2M3FN", b.getType<Float6E2M3FNType>())
.Case("f6E3M2FN", b.getType<Float6E3M2FNType>())
.Case("f8E5M2", b.getType<Float8E5M2Type>())
.Case("f8E4M3", b.getType<Float8E4M3Type>())
.Case("f8E4M3FN", b.getType<Float8E4M3FNType>())
.Case("f8E5M2FNUZ", b.getType<Float8E5M2FNUZType>())
.Case("f8E4M3FNUZ", b.getType<Float8E4M3FNUZType>())
.Case("f8E3M4", b.getType<Float8E3M4Type>())
.Case("f8E8M0FNU", b.getType<Float8E8M0FNUType>())
.Case("bf16", b.getType<BFloat16Type>())
.Case("f16", b.getType<Float16Type>())
.Case("f32", b.getType<Float32Type>())
.Case("f64", b.getType<Float64Type>())
.Case("f80", b.getType<Float80Type>())
.Case("f128", b.getType<Float128Type>())
.Default(std::nullopt);
}

Expand Down
38 changes: 0 additions & 38 deletions mlir/lib/IR/Builders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,44 +34,6 @@ Location Builder::getFusedLoc(ArrayRef<Location> locs, Attribute metadata) {
// Types.
//===----------------------------------------------------------------------===//

FloatType Builder::getFloat4E2M1FNType() {
return Float4E2M1FNType::get(context);
}

FloatType Builder::getFloat6E2M3FNType() {
return Float6E2M3FNType::get(context);
}

FloatType Builder::getFloat6E3M2FNType() {
return Float6E3M2FNType::get(context);
}

FloatType Builder::getFloat8E5M2Type() { return Float8E5M2Type::get(context); }

FloatType Builder::getFloat8E4M3Type() { return Float8E4M3Type::get(context); }

FloatType Builder::getFloat8E4M3FNType() {
return Float8E4M3FNType::get(context);
}

FloatType Builder::getFloat8E5M2FNUZType() {
return Float8E5M2FNUZType::get(context);
}

FloatType Builder::getFloat8E4M3FNUZType() {
return Float8E4M3FNUZType::get(context);
}

FloatType Builder::getFloat8E4M3B11FNUZType() {
return Float8E4M3B11FNUZType::get(context);
}

FloatType Builder::getFloat8E3M4Type() { return Float8E3M4Type::get(context); }

FloatType Builder::getFloat8E8M0FNUType() {
return Float8E8M0FNUType::get(context);
}

FloatType Builder::getBF16Type() { return BFloat16Type::get(context); }

FloatType Builder::getF16Type() { return Float16Type::get(context); }
Expand Down
55 changes: 0 additions & 55 deletions mlir/lib/IR/MLIRContext.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,17 +221,6 @@ class MLIRContextImpl {
llvm::DenseMap<StringRef, AbstractType *> nameToType;

/// Cached Type Instances.
Float4E2M1FNType f4E2M1FNTy;
Float6E2M3FNType f6E2M3FNTy;
Float6E3M2FNType f6E3M2FNTy;
Float8E5M2Type f8E5M2Ty;
Float8E4M3Type f8E4M3Ty;
Float8E4M3FNType f8E4M3FNTy;
Float8E5M2FNUZType f8E5M2FNUZTy;
Float8E4M3FNUZType f8E4M3FNUZTy;
Float8E4M3B11FNUZType f8E4M3B11FNUZTy;
Float8E3M4Type f8E3M4Ty;
Float8E8M0FNUType f8E8M0FNUTy;
BFloat16Type bf16Ty;
Float16Type f16Ty;
FloatTF32Type tf32Ty;
Expand Down Expand Up @@ -317,17 +306,6 @@ MLIRContext::MLIRContext(const DialectRegistry &registry, Threading setting)

//// Types.
/// Floating-point Types.
impl->f4E2M1FNTy = TypeUniquer::get<Float4E2M1FNType>(this);
impl->f6E2M3FNTy = TypeUniquer::get<Float6E2M3FNType>(this);
impl->f6E3M2FNTy = TypeUniquer::get<Float6E3M2FNType>(this);
impl->f8E5M2Ty = TypeUniquer::get<Float8E5M2Type>(this);
impl->f8E4M3Ty = TypeUniquer::get<Float8E4M3Type>(this);
impl->f8E4M3FNTy = TypeUniquer::get<Float8E4M3FNType>(this);
impl->f8E5M2FNUZTy = TypeUniquer::get<Float8E5M2FNUZType>(this);
impl->f8E4M3FNUZTy = TypeUniquer::get<Float8E4M3FNUZType>(this);
impl->f8E4M3B11FNUZTy = TypeUniquer::get<Float8E4M3B11FNUZType>(this);
impl->f8E3M4Ty = TypeUniquer::get<Float8E3M4Type>(this);
impl->f8E8M0FNUTy = TypeUniquer::get<Float8E8M0FNUType>(this);
impl->bf16Ty = TypeUniquer::get<BFloat16Type>(this);
impl->f16Ty = TypeUniquer::get<Float16Type>(this);
impl->tf32Ty = TypeUniquer::get<FloatTF32Type>(this);
Expand Down Expand Up @@ -1044,39 +1022,6 @@ AbstractType::lookup(StringRef name, MLIRContext *context) {
/// This should not be used directly.
StorageUniquer &MLIRContext::getTypeUniquer() { return getImpl().typeUniquer; }

Float4E2M1FNType Float4E2M1FNType::get(MLIRContext *context) {
return context->getImpl().f4E2M1FNTy;
}
Float6E2M3FNType Float6E2M3FNType::get(MLIRContext *context) {
return context->getImpl().f6E2M3FNTy;
}
Float6E3M2FNType Float6E3M2FNType::get(MLIRContext *context) {
return context->getImpl().f6E3M2FNTy;
}
Float8E5M2Type Float8E5M2Type::get(MLIRContext *context) {
return context->getImpl().f8E5M2Ty;
}
Float8E4M3Type Float8E4M3Type::get(MLIRContext *context) {
return context->getImpl().f8E4M3Ty;
}
Float8E4M3FNType Float8E4M3FNType::get(MLIRContext *context) {
return context->getImpl().f8E4M3FNTy;
}
Float8E5M2FNUZType Float8E5M2FNUZType::get(MLIRContext *context) {
return context->getImpl().f8E5M2FNUZTy;
}
Float8E4M3FNUZType Float8E4M3FNUZType::get(MLIRContext *context) {
return context->getImpl().f8E4M3FNUZTy;
}
Float8E4M3B11FNUZType Float8E4M3B11FNUZType::get(MLIRContext *context) {
return context->getImpl().f8E4M3B11FNUZTy;
}
Float8E3M4Type Float8E3M4Type::get(MLIRContext *context) {
return context->getImpl().f8E3M4Ty;
}
Float8E8M0FNUType Float8E8M0FNUType::get(MLIRContext *context) {
return context->getImpl().f8E8M0FNUTy;
}
BFloat16Type BFloat16Type::get(MLIRContext *context) {
return context->getImpl().bf16Ty;
}
Expand Down