Skip to content

Commit 1e1bf79

Browse files
authored
[mlir][emitc] Add an option to cast array type to ptr type (#126385)
1 parent ff99af7 commit 1e1bf79

File tree

4 files changed

+54
-10
lines changed

4 files changed

+54
-10
lines changed

mlir/include/mlir/Dialect/EmitC/IR/EmitC.td

+1-2
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> {
313313

314314
def EmitC_CastOp : EmitC_Op<"cast",
315315
[CExpression,
316-
DeclareOpInterfaceMethods<CastOpInterface>,
317-
SameOperandsAndResultShape]> {
316+
DeclareOpInterfaceMethods<CastOpInterface>]> {
318317
let summary = "Cast operation";
319318
let description = [{
320319
The `emitc.cast` operation performs an explicit type conversion and is emitted

mlir/lib/Dialect/EmitC/IR/EmitC.cpp

+13-5
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,14 @@ LogicalResult emitc::AssignOp::verify() {
305305
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
306306
Type input = inputs.front(), output = outputs.front();
307307

308+
if (auto arrayType = dyn_cast<emitc::ArrayType>(input)) {
309+
if (auto pointerType = dyn_cast<emitc::PointerType>(output)) {
310+
return (arrayType.getElementType() == pointerType.getPointee()) &&
311+
arrayType.getShape().size() == 1 && arrayType.getShape()[0] >= 1;
312+
}
313+
return false;
314+
}
315+
308316
return (
309317
(emitc::isIntegerIndexOrOpaqueType(input) ||
310318
emitc::isSupportedFloatType(input) || isa<emitc::PointerType>(input)) &&
@@ -757,9 +765,9 @@ void IfOp::print(OpAsmPrinter &p) {
757765

758766
/// Given the region at `index`, or the parent operation if `index` is None,
759767
/// return the successor regions. These are the regions that may be selected
760-
/// during the flow of control. `operands` is a set of optional attributes that
761-
/// correspond to a constant value for each operand, or null if that operand is
762-
/// not a constant.
768+
/// during the flow of control. `operands` is a set of optional attributes
769+
/// that correspond to a constant value for each operand, or null if that
770+
/// operand is not a constant.
763771
void IfOp::getSuccessorRegions(RegionBranchPoint point,
764772
SmallVectorImpl<RegionSuccessor> &regions) {
765773
// The `then` and the `else` region branch back to the parent operation.
@@ -1086,8 +1094,8 @@ emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
10861094
LogicalResult mlir::emitc::LValueType::verify(
10871095
llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
10881096
mlir::Type value) {
1089-
// Check that the wrapped type is valid. This especially forbids nested lvalue
1090-
// types.
1097+
// Check that the wrapped type is valid. This especially forbids nested
1098+
// lvalue types.
10911099
if (!isSupportedEmitCType(value))
10921100
return emitError()
10931101
<< "!emitc.lvalue must wrap supported emitc type, but got " << value;

mlir/test/Dialect/EmitC/invalid_ops.mlir

+35-3
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,41 @@ func.func @cast_tensor(%arg : tensor<f32>) {
130130

131131
// -----
132132

133-
func.func @cast_array(%arg : !emitc.array<4xf32>) {
134-
// expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<4xf32>' and result type '!emitc.array<4xf32>' are cast incompatible}}
135-
%1 = emitc.cast %arg: !emitc.array<4xf32> to !emitc.array<4xf32>
133+
func.func @cast_to_array(%arg : f32) {
134+
// expected-error @+1 {{'emitc.cast' op operand type 'f32' and result type '!emitc.array<4xf32>' are cast incompatible}}
135+
%1 = emitc.cast %arg: f32 to !emitc.array<4xf32>
136+
return
137+
}
138+
139+
// -----
140+
141+
func.func @cast_multidimensional_array(%arg : !emitc.array<1x2xi32>) {
142+
// expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<1x2xi32>' and result type '!emitc.ptr<i32>' are cast incompatible}}
143+
%1 = emitc.cast %arg: !emitc.array<1x2xi32> to !emitc.ptr<i32>
144+
return
145+
}
146+
147+
// -----
148+
149+
func.func @cast_array_zero_rank(%arg : !emitc.array<0xi32>) {
150+
// expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<0xi32>' and result type '!emitc.ptr<i32>' are cast incompatible}}
151+
%1 = emitc.cast %arg: !emitc.array<0xi32> to !emitc.ptr<i32>
152+
return
153+
}
154+
155+
// -----
156+
157+
func.func @cast_array_to_pointer_types_mismatch(%arg : !emitc.array<3xi32>) {
158+
// expected-error @+1 {{'emitc.cast' op operand type '!emitc.array<3xi32>' and result type '!emitc.ptr<f16>' are cast incompatible}}
159+
%1 = emitc.cast %arg: !emitc.array<3xi32> to !emitc.ptr<f16>
160+
return
161+
}
162+
163+
// -----
164+
165+
func.func @cast_pointer_to_array(%arg : !emitc.ptr<i32>) {
166+
// expected-error @+1 {{'emitc.cast' op operand type '!emitc.ptr<i32>' and result type '!emitc.array<3xi32>' are cast incompatible}}
167+
%1 = emitc.cast %arg: !emitc.ptr<i32> to !emitc.array<3xi32>
136168
return
137169
}
138170

mlir/test/Dialect/EmitC/ops.mlir

+5
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,11 @@ func.func @cast(%arg0: i32) {
3939
return
4040
}
4141

42+
func.func @cast_array_to_pointer(%arg0: !emitc.array<3xi32>) {
43+
%1 = emitc.cast %arg0: !emitc.array<3xi32> to !emitc.ptr<i32>
44+
return
45+
}
46+
4247
func.func @c() {
4348
%1 = "emitc.constant"(){value = 42 : i32} : () -> i32
4449
%2 = "emitc.constant"(){value = 42 : index} : () -> !emitc.size_t

0 commit comments

Comments
 (0)