@@ -454,31 +454,25 @@ struct SerializeSplatTransferReadWithTargetLoadSize
454
454
455
455
// Find the arith.constant that the vector operand is a view of, if it is
456
456
// one.
457
- arith::ConstantOp constantVectorSource = [&writeOp]() -> arith::ConstantOp {
458
- Value current = writeOp.getVector ();
459
- while (Operation *op = current.getDefiningOp ()) {
460
- if (auto cOp = dyn_cast<arith::ConstantOp>(op)) return cOp;
461
- if (op->getNumOperands () != 1 ) return {};
462
- current = op->getOperand (0 );
463
- }
464
- return {};
465
- }();
466
- if (!constantVectorSource) {
457
+ Value currentTraversalValue = writeOp.getVector ();
458
+ arith::ConstantOp vectorSource;
459
+ while (Operation *op = currentTraversalValue.getDefiningOp ()) {
460
+ if (auto cOp = dyn_cast<arith::ConstantOp>(op)) vectorSource = cOp;
461
+ if (vectorSource || op->getNumOperands () != 1 ) break ;
462
+ currentTraversalValue = op->getOperand (0 );
463
+ }
464
+ if (!vectorSource) {
467
465
return rewriter.notifyMatchFailure (
468
466
writeOp, " vector isn't derived from arith.constant" );
469
467
}
470
468
471
469
// Get the splat value of the constant vector.
472
- auto maybeSplat = [&]() -> FailureOr<Attribute> {
473
- TypedAttr constantValue = constantVectorSource.getValue ();
474
- auto splatAttr = dyn_cast<SplatElementsAttr>(constantValue);
475
- if (!splatAttr || !splatAttr.isSplat ()) {
476
- return rewriter.notifyMatchFailure (writeOp, " constant isn't a splat" );
477
- }
478
- return splatAttr.getSplatValue <Attribute>();
479
- }();
480
- if (failed (maybeSplat)) return failure ();
481
- Attribute splat = maybeSplat.value ();
470
+ TypedAttr constantValue = vectorSource.getValue ();
471
+ auto splatAttr = dyn_cast<SplatElementsAttr>(constantValue);
472
+ if (!splatAttr || !splatAttr.isSplat ()) {
473
+ return rewriter.notifyMatchFailure (writeOp, " constant isn't a splat" );
474
+ }
475
+ Attribute splat = splatAttr.getSplatValue <Attribute>();
482
476
483
477
int64_t bytesPerWrite = deviceModel.getPreferredLoadBytes ();
484
478
@@ -504,8 +498,8 @@ struct SerializeSplatTransferReadWithTargetLoadSize
504
498
auto createTransferWrite = [&](uint32_t n, Value offset) {
505
499
VectorType type = VectorType::get ({n}, elementType);
506
500
DenseElementsAttr attr = DenseElementsAttr::get (type, splat);
507
- auto newConstantOp = rewriter. create <arith::ConstantOp>(
508
- constantVectorSource .getLoc (), type, attr);
501
+ auto newConstantOp =
502
+ rewriter. create <arith::ConstantOp>(vectorSource .getLoc (), type, attr);
509
503
rewriter.create <vector::TransferWriteOp>(
510
504
writeOp.getLoc (), newConstantOp.getResult (), writeDestination, offset,
511
505
writeOp.getPermutationMapAttr (), writeOp.getInBoundsAttr ());
0 commit comments