Skip to content

Commit

Permalink
further comments addressed
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Feb 21, 2025
1 parent 4a5f0db commit 8bc90eb
Showing 1 changed file with 16 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -454,31 +454,25 @@ struct SerializeSplatTransferReadWithTargetLoadSize

// Find the arith.constant that the vector operand is a view of, if it is
// one.
arith::ConstantOp constantVectorSource = [&writeOp]() -> arith::ConstantOp {
Value current = writeOp.getVector();
while (Operation *op = current.getDefiningOp()) {
if (auto cOp = dyn_cast<arith::ConstantOp>(op)) return cOp;
if (op->getNumOperands() != 1) return {};
current = op->getOperand(0);
}
return {};
}();
if (!constantVectorSource) {
Value currentTraversalValue = writeOp.getVector();
arith::ConstantOp vectorSource;
while (Operation *op = currentTraversalValue.getDefiningOp()) {
if (auto cOp = dyn_cast<arith::ConstantOp>(op)) vectorSource = cOp;
if (vectorSource || op->getNumOperands() != 1) break;
currentTraversalValue = op->getOperand(0);
}
if (!vectorSource) {
return rewriter.notifyMatchFailure(
writeOp, "vector isn't derived from arith.constant");
}

// Get the splat value of the constant vector.
auto maybeSplat = [&]() -> FailureOr<Attribute> {
TypedAttr constantValue = constantVectorSource.getValue();
auto splatAttr = dyn_cast<SplatElementsAttr>(constantValue);
if (!splatAttr || !splatAttr.isSplat()) {
return rewriter.notifyMatchFailure(writeOp, "constant isn't a splat");
}
return splatAttr.getSplatValue<Attribute>();
}();
if (failed(maybeSplat)) return failure();
Attribute splat = maybeSplat.value();
TypedAttr constantValue = vectorSource.getValue();
auto splatAttr = dyn_cast<SplatElementsAttr>(constantValue);
if (!splatAttr || !splatAttr.isSplat()) {
return rewriter.notifyMatchFailure(writeOp, "constant isn't a splat");
}
Attribute splat = splatAttr.getSplatValue<Attribute>();

int64_t bytesPerWrite = deviceModel.getPreferredLoadBytes();

Expand All @@ -504,8 +498,8 @@ struct SerializeSplatTransferReadWithTargetLoadSize
auto createTransferWrite = [&](uint32_t n, Value offset) {
VectorType type = VectorType::get({n}, elementType);
DenseElementsAttr attr = DenseElementsAttr::get(type, splat);
auto newConstantOp = rewriter.create<arith::ConstantOp>(
constantVectorSource.getLoc(), type, attr);
auto newConstantOp =
rewriter.create<arith::ConstantOp>(vectorSource.getLoc(), type, attr);
rewriter.create<vector::TransferWriteOp>(
writeOp.getLoc(), newConstantOp.getResult(), writeDestination, offset,
writeOp.getPermutationMapAttr(), writeOp.getInBoundsAttr());
Expand Down

0 comments on commit 8bc90eb

Please sign in to comment.