@@ -454,31 +454,25 @@ struct SerializeSplatTransferReadWithTargetLoadSize
454454
455455 // Find the arith.constant that the vector operand is a view of, if it is
456456 // one.
457- arith::ConstantOp constantVectorSource = [&writeOp]() -> arith::ConstantOp {
458- Value current = writeOp.getVector ();
459- while (Operation *op = current.getDefiningOp ()) {
460- if (auto cOp = dyn_cast<arith::ConstantOp>(op)) return cOp;
461- if (op->getNumOperands () != 1 ) return {};
462- current = op->getOperand (0 );
463- }
464- return {};
465- }();
466- if (!constantVectorSource) {
457+ Value currentTraversalValue = writeOp.getVector ();
458+ arith::ConstantOp vectorSource;
459+ while (Operation *op = currentTraversalValue.getDefiningOp ()) {
460+ if (auto cOp = dyn_cast<arith::ConstantOp>(op)) vectorSource = cOp;
461+ if (vectorSource || op->getNumOperands () != 1 ) break ;
462+ currentTraversalValue = op->getOperand (0 );
463+ }
464+ if (!vectorSource) {
467465 return rewriter.notifyMatchFailure (
468466 writeOp, " vector isn't derived from arith.constant" );
469467 }
470468
471469 // Get the splat value of the constant vector.
472- auto maybeSplat = [&]() -> FailureOr<Attribute> {
473- TypedAttr constantValue = constantVectorSource.getValue ();
474- auto splatAttr = dyn_cast<SplatElementsAttr>(constantValue);
475- if (!splatAttr || !splatAttr.isSplat ()) {
476- return rewriter.notifyMatchFailure (writeOp, " constant isn't a splat" );
477- }
478- return splatAttr.getSplatValue <Attribute>();
479- }();
480- if (failed (maybeSplat)) return failure ();
481- Attribute splat = maybeSplat.value ();
470+ TypedAttr constantValue = vectorSource.getValue ();
471+ auto splatAttr = dyn_cast<SplatElementsAttr>(constantValue);
472+ if (!splatAttr || !splatAttr.isSplat ()) {
473+ return rewriter.notifyMatchFailure (writeOp, " constant isn't a splat" );
474+ }
475+ Attribute splat = splatAttr.getSplatValue <Attribute>();
482476
483477 int64_t bytesPerWrite = deviceModel.getPreferredLoadBytes ();
484478
@@ -504,8 +498,8 @@ struct SerializeSplatTransferReadWithTargetLoadSize
504498 auto createTransferWrite = [&](uint32_t n, Value offset) {
505499 VectorType type = VectorType::get ({n}, elementType);
506500 DenseElementsAttr attr = DenseElementsAttr::get (type, splat);
507- auto newConstantOp = rewriter. create <arith::ConstantOp>(
508- constantVectorSource .getLoc (), type, attr);
501+ auto newConstantOp =
502+ rewriter. create <arith::ConstantOp>(vectorSource .getLoc (), type, attr);
509503 rewriter.create <vector::TransferWriteOp>(
510504 writeOp.getLoc (), newConstantOp.getResult (), writeDestination, offset,
511505 writeOp.getPermutationMapAttr (), writeOp.getInBoundsAttr ());
0 commit comments