Skip to content

Commit 0d906a4

Browse files
[mlir][Transforms] Dialect conversion: add originalType param to materializations (#112128)
This commit adds an optional `originalType` parameter to target materialization functions. Without this parameter, target materializations are underspecified. Note: `originalType` is only needed for target materializations. Source/argument materializations do not have it. Consider the following example: Let's assume that a conversion pattern "P1" replaced an SSA value "v1" (type "t1") with "v2" (type "t2"). Then a different conversion pattern "P2" matches an op that has "v1" as an operand. Let's furthermore assume that "P2" determines that the legalized type of "t1" is "t3", which may be different from "t2". In this example, the target materialization callback will be invoked with: outputType = "t3", inputs = "v2", originalType = "t1". Note that the original type "t1" cannot be recovered from just "t3" and "v2"; that's why the `originalType` parameter is added. This change is in preparation of merging the 1:1 and 1:N dialect conversion drivers. As part of that change, argument materializations will be removed (as they are no longer needed; they were just a workaround because of missing 1:N support in the dialect conversion). The new `originalType` parameter is needed when lowering MemRef to LLVM. During that lowering, MemRef function block arguments are replaced with the elements that make up a MemRef descriptor. The type converter is set up in such a way that the legalized type of a MemRef type is an `!llvm.struct` that represents the MemRef descriptor. When the bare pointer calling convention is enabled, the function block arguments consist of just an LLVM pointer. In such a case, a target materialization will be invoked to construct a MemRef descriptor (output type = `!llvm.struct<...>`) from just the bare pointer (inputs = `!llvm.ptr`). The original MemRef type is required to construct the MemRef descriptor, as static sizes/strides/offset cannot be inferred from just the bare pointer.
1 parent 1ae9fe5 commit 0d906a4

File tree

2 files changed

+136
-55
lines changed

2 files changed

+136
-55
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

+72-36
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ class TypeConverter {
138138
};
139139

140140
/// Register a conversion function. A conversion function must be convertible
141-
/// to any of the following forms(where `T` is a class derived from `Type`:
141+
/// to any of the following forms (where `T` is a class derived from `Type`):
142+
///
142143
/// * std::optional<Type>(T)
143144
/// - This form represents a 1-1 type conversion. It should return nullptr
144145
/// or `std::nullopt` to signify failure. If `std::nullopt` is returned,
@@ -151,15 +152,7 @@ class TypeConverter {
151152
/// existing value are expected to be removed during conversion. If
152153
/// `std::nullopt` is returned, the converter is allowed to try another
153154
/// conversion function to perform the conversion.
154-
/// * std::optional<LogicalResult>(T, SmallVectorImpl<Type> &,
155-
/// ArrayRef<Type>)
156-
/// - This form represents a 1-N type conversion supporting recursive
157-
/// types. The first two arguments and the return value are the same as
158-
/// for the regular 1-N form. The third argument is contains is the
159-
/// "call stack" of the recursive conversion: it contains the list of
160-
/// types currently being converted, with the current type being the
161-
/// last one. If it is present more than once in the list, the
162-
/// conversion concerns a recursive type.
155+
///
163156
/// Note: When attempting to convert a type, e.g. via 'convertType', the
164157
/// mostly recently added conversions will be invoked first.
165158
template <typename FnT, typename T = typename llvm::function_traits<
@@ -178,6 +171,9 @@ class TypeConverter {
178171
/// it failed but other materialization can be attempted, and `nullptr` on
179172
/// unrecoverable failure. Materialization functions must be provided when a
180173
/// type conversion may persist after the conversion has finished.
174+
///
175+
/// Note: Target materializations may optionally accept an additional Type
176+
/// parameter, which is the original type of the SSA value.
181177

182178
/// This method registers a materialization that will be called when
183179
/// converting (potentially multiple) block arguments that were the result of
@@ -203,11 +199,22 @@ class TypeConverter {
203199

204200
/// This method registers a materialization that will be called when
205201
/// converting an illegal (source) value to a legal (target) type.
202+
///
203+
/// Note: For target materializations, users can optionally take the original
204+
/// type. This type may be different from the type of the input. For example,
205+
/// let's assume that a conversion pattern "P1" replaced an SSA value "v1"
206+
/// (type "t1") with "v2" (type "t2"). Then a different conversion pattern
207+
/// "P2" matches an op that has "v1" as an operand. Let's furthermore assume
208+
/// that "P2" determines that the legalized type of "t1" is "t3", which may
209+
/// be different from "t2". In this example, the target materialization
210+
/// will be invoked with: outputType = "t3", inputs = "v2",
211+
// originalType = "t1". Note that the original type "t1" cannot be recovered
212+
/// from just "t3" and "v2"; that's why the originalType parameter exists.
206213
template <typename FnT, typename T = typename llvm::function_traits<
207214
std::decay_t<FnT>>::template arg_t<1>>
208215
void addTargetMaterialization(FnT &&callback) {
209216
targetMaterializations.emplace_back(
210-
wrapMaterialization<T>(std::forward<FnT>(callback)));
217+
wrapTargetMaterialization<T>(std::forward<FnT>(callback)));
211218
}
212219

213220
/// Register a conversion function for attributes within types. Type
@@ -303,21 +310,12 @@ class TypeConverter {
303310
/// `add*Materialization` for more information on the context for these
304311
/// methods.
305312
Value materializeArgumentConversion(OpBuilder &builder, Location loc,
306-
Type resultType,
307-
ValueRange inputs) const {
308-
return materializeConversion(argumentMaterializations, builder, loc,
309-
resultType, inputs);
310-
}
313+
Type resultType, ValueRange inputs) const;
311314
Value materializeSourceConversion(OpBuilder &builder, Location loc,
312-
Type resultType, ValueRange inputs) const {
313-
return materializeConversion(sourceMaterializations, builder, loc,
314-
resultType, inputs);
315-
}
315+
Type resultType, ValueRange inputs) const;
316316
Value materializeTargetConversion(OpBuilder &builder, Location loc,
317-
Type resultType, ValueRange inputs) const {
318-
return materializeConversion(targetMaterializations, builder, loc,
319-
resultType, inputs);
320-
}
317+
Type resultType, ValueRange inputs,
318+
Type originalType = {}) const;
321319

322320
/// Convert an attribute present `attr` from within the type `type` using
323321
/// the registered conversion functions. If no applicable conversion has been
@@ -333,21 +331,23 @@ class TypeConverter {
333331
using ConversionCallbackFn = std::function<std::optional<LogicalResult>(
334332
Type, SmallVectorImpl<Type> &)>;
335333

336-
/// The signature of the callback used to materialize a conversion.
334+
/// The signature of the callback used to materialize a source/argument
335+
/// conversion.
336+
///
337+
/// Arguments: builder, result type, inputs, location
337338
using MaterializationCallbackFn = std::function<std::optional<Value>(
338339
OpBuilder &, Type, ValueRange, Location)>;
339340

341+
/// The signature of the callback used to materialize a target conversion.
342+
///
343+
/// Arguments: builder, result type, inputs, location, original type
344+
using TargetMaterializationCallbackFn = std::function<std::optional<Value>(
345+
OpBuilder &, Type, ValueRange, Location, Type)>;
346+
340347
/// The signature of the callback used to convert a type attribute.
341348
using TypeAttributeConversionCallbackFn =
342349
std::function<AttributeConversionResult(Type, Attribute)>;
343350

344-
/// Attempt to materialize a conversion using one of the provided
345-
/// materialization functions.
346-
Value
347-
materializeConversion(ArrayRef<MaterializationCallbackFn> materializations,
348-
OpBuilder &builder, Location loc, Type resultType,
349-
ValueRange inputs) const;
350-
351351
/// Generate a wrapper for the given callback. This allows for accepting
352352
/// different callback forms, that all compose into a single version.
353353
/// With callback of form: `std::optional<Type>(T)`
@@ -388,9 +388,10 @@ class TypeConverter {
388388
cachedMultiConversions.clear();
389389
}
390390

391-
/// Generate a wrapper for the given materialization callback. The callback
392-
/// may take any subclass of `Type` and the wrapper will check for the target
393-
/// type to be of the expected class before calling the callback.
391+
/// Generate a wrapper for the given argument/source materialization
392+
/// callback. The callback may take any subclass of `Type` and the
393+
/// wrapper will check for the target type to be of the expected class
394+
/// before calling the callback.
394395
template <typename T, typename FnT>
395396
MaterializationCallbackFn wrapMaterialization(FnT &&callback) const {
396397
return [callback = std::forward<FnT>(callback)](
@@ -402,6 +403,41 @@ class TypeConverter {
402403
};
403404
}
404405

406+
/// Generate a wrapper for the given target materialization callback.
407+
/// The callback may take any subclass of `Type` and the wrapper will check
408+
/// for the target type to be of the expected class before calling the
409+
/// callback.
410+
///
411+
/// With callback of form:
412+
/// `Value(OpBuilder &, T, ValueRange, Location, Type)`
413+
template <typename T, typename FnT>
414+
std::enable_if_t<
415+
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location, Type>,
416+
TargetMaterializationCallbackFn>
417+
wrapTargetMaterialization(FnT &&callback) const {
418+
return [callback = std::forward<FnT>(callback)](
419+
OpBuilder &builder, Type resultType, ValueRange inputs,
420+
Location loc, Type originalType) -> std::optional<Value> {
421+
if (T derivedType = dyn_cast<T>(resultType))
422+
return callback(builder, derivedType, inputs, loc, originalType);
423+
return std::nullopt;
424+
};
425+
}
426+
/// With callback of form:
427+
/// `Value(OpBuilder &, T, ValueRange, Location)`
428+
template <typename T, typename FnT>
429+
std::enable_if_t<
430+
std::is_invocable_v<FnT, OpBuilder &, T, ValueRange, Location>,
431+
TargetMaterializationCallbackFn>
432+
wrapTargetMaterialization(FnT &&callback) const {
433+
return wrapTargetMaterialization<T>(
434+
[callback = std::forward<FnT>(callback)](
435+
OpBuilder &builder, T resultType, ValueRange inputs, Location loc,
436+
Type originalType) -> std::optional<Value> {
437+
return callback(builder, resultType, inputs, loc);
438+
});
439+
}
440+
405441
/// Generate a wrapper for the given memory space conversion callback. The
406442
/// callback may take any subclass of `Attribute` and the wrapper will check
407443
/// for the target attribute to be of the expected class before calling the
@@ -434,7 +470,7 @@ class TypeConverter {
434470
/// The list of registered materialization functions.
435471
SmallVector<MaterializationCallbackFn, 2> argumentMaterializations;
436472
SmallVector<MaterializationCallbackFn, 2> sourceMaterializations;
437-
SmallVector<MaterializationCallbackFn, 2> targetMaterializations;
473+
SmallVector<TargetMaterializationCallbackFn, 2> targetMaterializations;
438474

439475
/// The list of registered type attribute conversion functions.
440476
SmallVector<TypeAttributeConversionCallbackFn, 2> typeAttributeConversions;

mlir/lib/Transforms/Utils/DialectConversion.cpp

+64-19
Original file line numberDiff line numberDiff line change
@@ -683,10 +683,10 @@ enum MaterializationKind {
683683
/// conversion.
684684
class UnresolvedMaterializationRewrite : public OperationRewrite {
685685
public:
686-
UnresolvedMaterializationRewrite(
687-
ConversionPatternRewriterImpl &rewriterImpl,
688-
UnrealizedConversionCastOp op, const TypeConverter *converter = nullptr,
689-
MaterializationKind kind = MaterializationKind::Target);
686+
UnresolvedMaterializationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
687+
UnrealizedConversionCastOp op,
688+
const TypeConverter *converter,
689+
MaterializationKind kind, Type originalType);
690690

691691
static bool classof(const IRRewrite *rewrite) {
692692
return rewrite->getKind() == Kind::UnresolvedMaterialization;
@@ -708,11 +708,18 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
708708
return converterAndKind.getInt();
709709
}
710710

711+
/// Return the original type of the SSA value.
712+
Type getOriginalType() const { return originalType; }
713+
711714
private:
712715
/// The corresponding type converter to use when resolving this
713716
/// materialization, and the kind of this materialization.
714717
llvm::PointerIntPair<const TypeConverter *, 2, MaterializationKind>
715718
converterAndKind;
719+
720+
/// The original type of the SSA value. Only used for target
721+
/// materializations.
722+
Type originalType;
716723
};
717724
} // namespace
718725

@@ -808,6 +815,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
808815
Value buildUnresolvedMaterialization(MaterializationKind kind,
809816
OpBuilder::InsertPoint ip, Location loc,
810817
ValueRange inputs, Type outputType,
818+
Type originalType,
811819
const TypeConverter *converter);
812820

813821
//===--------------------------------------------------------------------===//
@@ -1034,9 +1042,12 @@ void CreateOperationRewrite::rollback() {
10341042

10351043
UnresolvedMaterializationRewrite::UnresolvedMaterializationRewrite(
10361044
ConversionPatternRewriterImpl &rewriterImpl, UnrealizedConversionCastOp op,
1037-
const TypeConverter *converter, MaterializationKind kind)
1045+
const TypeConverter *converter, MaterializationKind kind, Type originalType)
10381046
: OperationRewrite(Kind::UnresolvedMaterialization, rewriterImpl, op),
1039-
converterAndKind(converter, kind) {
1047+
converterAndKind(converter, kind), originalType(originalType) {
1048+
assert(!originalType ||
1049+
kind == MaterializationKind::Target &&
1050+
"original type is valid only for target materializations");
10401051
rewriterImpl.unresolvedMaterializations[op] = this;
10411052
}
10421053

@@ -1139,7 +1150,7 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
11391150
Value castValue = buildUnresolvedMaterialization(
11401151
MaterializationKind::Target, computeInsertPoint(newOperand),
11411152
operandLoc, /*inputs=*/newOperand, /*outputType=*/desiredType,
1142-
currentTypeConverter);
1153+
/*originalType=*/origType, currentTypeConverter);
11431154
mapping.map(newOperand, castValue);
11441155
newOperand = castValue;
11451156
}
@@ -1255,7 +1266,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12551266
MaterializationKind::Source,
12561267
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
12571268
/*inputs=*/ValueRange(),
1258-
/*outputType=*/origArgType, converter);
1269+
/*outputType=*/origArgType, /*originalType=*/Type(), converter);
12591270
mapping.map(origArg, repl);
12601271
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
12611272
continue;
@@ -1280,7 +1291,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12801291
Value argMat = buildUnresolvedMaterialization(
12811292
MaterializationKind::Argument,
12821293
OpBuilder::InsertPoint(newBlock, newBlock->begin()), origArg.getLoc(),
1283-
/*inputs=*/replArgs, origArgType, converter);
1294+
/*inputs=*/replArgs, /*outputType=*/origArgType,
1295+
/*originalType=*/Type(), converter);
12841296
mapping.map(origArg, argMat);
12851297

12861298
Type legalOutputType;
@@ -1299,7 +1311,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
12991311
if (legalOutputType && legalOutputType != origArgType) {
13001312
Value targetMat = buildUnresolvedMaterialization(
13011313
MaterializationKind::Target, computeInsertPoint(argMat),
1302-
origArg.getLoc(), argMat, legalOutputType, converter);
1314+
origArg.getLoc(), /*inputs=*/argMat, /*outputType=*/legalOutputType,
1315+
/*originalType=*/origArgType, converter);
13031316
mapping.map(argMat, targetMat);
13041317
}
13051318
appendRewrite<ReplaceBlockArgRewrite>(block, origArg);
@@ -1322,7 +1335,12 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
13221335
/// of input operands.
13231336
Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
13241337
MaterializationKind kind, OpBuilder::InsertPoint ip, Location loc,
1325-
ValueRange inputs, Type outputType, const TypeConverter *converter) {
1338+
ValueRange inputs, Type outputType, Type originalType,
1339+
const TypeConverter *converter) {
1340+
assert(!originalType ||
1341+
kind == MaterializationKind::Target &&
1342+
"original type is valid only for target materializations");
1343+
13261344
// Avoid materializing an unnecessary cast.
13271345
if (inputs.size() == 1 && inputs.front().getType() == outputType)
13281346
return inputs.front();
@@ -1333,7 +1351,8 @@ Value ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
13331351
builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
13341352
auto convertOp =
13351353
builder.create<UnrealizedConversionCastOp>(loc, outputType, inputs);
1336-
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind);
1354+
appendRewrite<UnresolvedMaterializationRewrite>(convertOp, converter, kind,
1355+
originalType);
13371356
return convertOp.getResult(0);
13381357
}
13391358

@@ -1381,7 +1400,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
13811400
newValue = buildUnresolvedMaterialization(
13821401
MaterializationKind::Source, computeInsertPoint(result),
13831402
result.getLoc(), /*inputs=*/ValueRange(),
1384-
/*outputType=*/result.getType(), currentTypeConverter);
1403+
/*outputType=*/result.getType(), /*originalType=*/Type(),
1404+
currentTypeConverter);
13851405
}
13861406

13871407
// Remap, and check for any result type changes.
@@ -2408,7 +2428,8 @@ legalizeUnresolvedMaterialization(RewriterBase &rewriter,
24082428
[[fallthrough]];
24092429
case MaterializationKind::Target:
24102430
newMaterialization = converter->materializeTargetConversion(
2411-
rewriter, op->getLoc(), outputType, inputOperands);
2431+
rewriter, op->getLoc(), outputType, inputOperands,
2432+
rewrite->getOriginalType());
24122433
break;
24132434
case MaterializationKind::Source:
24142435
newMaterialization = converter->materializeSourceConversion(
@@ -2565,7 +2586,7 @@ void OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
25652586
MaterializationKind::Source, computeInsertPoint(newValue),
25662587
originalValue.getLoc(),
25672588
/*inputs=*/newValue, /*outputType=*/originalValue.getType(),
2568-
converter);
2589+
/*originalType=*/Type(), converter);
25692590
rewriterImpl.mapping.map(originalValue, castValue);
25702591
inverseMapping[castValue].push_back(originalValue);
25712592
llvm::erase(inverseMapping[newValue], originalValue);
@@ -2787,15 +2808,39 @@ TypeConverter::convertSignatureArgs(TypeRange types,
27872808
return success();
27882809
}
27892810

2790-
Value TypeConverter::materializeConversion(
2791-
ArrayRef<MaterializationCallbackFn> materializations, OpBuilder &builder,
2792-
Location loc, Type resultType, ValueRange inputs) const {
2793-
for (const MaterializationCallbackFn &fn : llvm::reverse(materializations))
2811+
Value TypeConverter::materializeArgumentConversion(OpBuilder &builder,
2812+
Location loc,
2813+
Type resultType,
2814+
ValueRange inputs) const {
2815+
for (const MaterializationCallbackFn &fn :
2816+
llvm::reverse(argumentMaterializations))
2817+
if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
2818+
return *result;
2819+
return nullptr;
2820+
}
2821+
2822+
Value TypeConverter::materializeSourceConversion(OpBuilder &builder,
2823+
Location loc, Type resultType,
2824+
ValueRange inputs) const {
2825+
for (const MaterializationCallbackFn &fn :
2826+
llvm::reverse(sourceMaterializations))
27942827
if (std::optional<Value> result = fn(builder, resultType, inputs, loc))
27952828
return *result;
27962829
return nullptr;
27972830
}
27982831

2832+
Value TypeConverter::materializeTargetConversion(OpBuilder &builder,
2833+
Location loc, Type resultType,
2834+
ValueRange inputs,
2835+
Type originalType) const {
2836+
for (const TargetMaterializationCallbackFn &fn :
2837+
llvm::reverse(targetMaterializations))
2838+
if (std::optional<Value> result =
2839+
fn(builder, resultType, inputs, loc, originalType))
2840+
return *result;
2841+
return nullptr;
2842+
}
2843+
27992844
std::optional<TypeConverter::SignatureConversion>
28002845
TypeConverter::convertBlockSignature(Block *block) const {
28012846
SignatureConversion conversion(block->getNumArguments());

0 commit comments

Comments
 (0)