Skip to content

Commit 8bc90eb

Browse files
committed
further comments addressed
1 parent 4a5f0db commit 8bc90eb

File tree

1 file changed

+16
-22
lines changed

1 file changed

+16
-22
lines changed

compiler/plugins/target/AMD-AIE/aievec/VectorToVectorConversions.cpp

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -454,31 +454,25 @@ struct SerializeSplatTransferReadWithTargetLoadSize
454454

455455
// Find the arith.constant that the vector operand is a view of, if it is
456456
// one.
457-
arith::ConstantOp constantVectorSource = [&writeOp]() -> arith::ConstantOp {
458-
Value current = writeOp.getVector();
459-
while (Operation *op = current.getDefiningOp()) {
460-
if (auto cOp = dyn_cast<arith::ConstantOp>(op)) return cOp;
461-
if (op->getNumOperands() != 1) return {};
462-
current = op->getOperand(0);
463-
}
464-
return {};
465-
}();
466-
if (!constantVectorSource) {
457+
Value currentTraversalValue = writeOp.getVector();
458+
arith::ConstantOp vectorSource;
459+
while (Operation *op = currentTraversalValue.getDefiningOp()) {
460+
if (auto cOp = dyn_cast<arith::ConstantOp>(op)) vectorSource = cOp;
461+
if (vectorSource || op->getNumOperands() != 1) break;
462+
currentTraversalValue = op->getOperand(0);
463+
}
464+
if (!vectorSource) {
467465
return rewriter.notifyMatchFailure(
468466
writeOp, "vector isn't derived from arith.constant");
469467
}
470468

471469
// Get the splat value of the constant vector.
472-
auto maybeSplat = [&]() -> FailureOr<Attribute> {
473-
TypedAttr constantValue = constantVectorSource.getValue();
474-
auto splatAttr = dyn_cast<SplatElementsAttr>(constantValue);
475-
if (!splatAttr || !splatAttr.isSplat()) {
476-
return rewriter.notifyMatchFailure(writeOp, "constant isn't a splat");
477-
}
478-
return splatAttr.getSplatValue<Attribute>();
479-
}();
480-
if (failed(maybeSplat)) return failure();
481-
Attribute splat = maybeSplat.value();
470+
TypedAttr constantValue = vectorSource.getValue();
471+
auto splatAttr = dyn_cast<SplatElementsAttr>(constantValue);
472+
if (!splatAttr || !splatAttr.isSplat()) {
473+
return rewriter.notifyMatchFailure(writeOp, "constant isn't a splat");
474+
}
475+
Attribute splat = splatAttr.getSplatValue<Attribute>();
482476

483477
int64_t bytesPerWrite = deviceModel.getPreferredLoadBytes();
484478

@@ -504,8 +498,8 @@ struct SerializeSplatTransferReadWithTargetLoadSize
504498
auto createTransferWrite = [&](uint32_t n, Value offset) {
505499
VectorType type = VectorType::get({n}, elementType);
506500
DenseElementsAttr attr = DenseElementsAttr::get(type, splat);
507-
auto newConstantOp = rewriter.create<arith::ConstantOp>(
508-
constantVectorSource.getLoc(), type, attr);
501+
auto newConstantOp =
502+
rewriter.create<arith::ConstantOp>(vectorSource.getLoc(), type, attr);
509503
rewriter.create<vector::TransferWriteOp>(
510504
writeOp.getLoc(), newConstantOp.getResult(), writeDestination, offset,
511505
writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());

0 commit comments

Comments
 (0)