diff --git a/include/ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h b/include/ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h index 39977768fb..40f3d08a17 100644 --- a/include/ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h +++ b/include/ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h @@ -12,6 +12,7 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/PatternMatch.h" #include "shardy/dialect/sdy/ir/dialect.h" +#include "stablehlo/dialect/StablehloOps.h" #include "llvm/Support/Error.h" #include "llvm/Support/ErrorHandling.h" @@ -28,53 +29,36 @@ class MeshSharding { llvm::Expected convertGSPMDShardingToMeshSharding(StringRef shardingStr); + // Check and update arg sharding attribute and determine if mesh_shard op + // needs to be created or not. + bool checkAndUpdateGSPMDArgSharding(mlir::PatternRewriter &rewriter, + mlir::stablehlo::CustomCallOp srcOp, + mlir::StringAttr shardingAttr); + + // Check and update ret sharding attribute and determine if mesh_shard op + // needs to be created or not. + bool checkAndUpdateGSPMDRetSharding(mlir::PatternRewriter &rewriter, + mlir::stablehlo::CustomCallOp srcOp, + mlir::StringAttr shardingAttr); + // Convert sdy.sharding to meshSharding. 
llvm::Expected convertSdyShardingToMeshSharding(mlir::sdy::TensorShardingAttr sdySharding, mlir::sdy::MeshAttr mesh, mlir::tt::MeshShardDirection direction); - // Check and update function arg sharding - template - void checkAndUpdateFuncArgSharding(mlir::PatternRewriter &rewriter, - mlir::func::FuncOp funcOp, uint64_t argNum, - AttrType shardingAttr, - llvm::StringRef argShardingStrRef) { - if (auto argShardingAttr = - funcOp.getArgAttrOfType(argNum, argShardingStrRef)) { - if (argShardingAttr == shardingAttr) { - setDummyShardingOp(); - rewriter.modifyOpInPlace( - funcOp, [&]() { funcOp.removeArgAttr(argNum, argShardingStrRef); }); - } else { - llvm_unreachable( - "MeshSharding operation and function argument shardings " - "are different."); - } - } - } + // Check and update arg sharding attribute and determine if + // mesh_shard op needs to be created or not. + bool checkAndUpdateShardyArgSharding( + mlir::PatternRewriter &rewriter, mlir::func::FuncOp funcOp, + mlir::Value argOperand, mlir::sdy::TensorShardingAttr shardingAttr); - // Check and update function ret sharding - template - void checkAndUpdateFuncReturnSharding(mlir::PatternRewriter &rewriter, - mlir::func::FuncOp funcOp, - uint64_t retNum, AttrType shardingAttr, - llvm::StringRef retShardingStrRef) { - if (auto retShardingAttr = - funcOp.getResultAttrOfType(retNum, retShardingStrRef)) { - if (retShardingAttr == shardingAttr) { - setDummyShardingOp(); - rewriter.modifyOpInPlace(funcOp, [&]() { - funcOp.removeResultAttr( - retNum, - mlir::StringAttr::get(rewriter.getContext(), retShardingStrRef)); - }); - } else { - llvm_unreachable("MeshSharding operation and function return shardings " - "are different."); - } - } - } + // Check and update ret sharding attribute and determine if mesh_shard op + // needs to be created or not. 
+ bool + checkAndUpdateShardyRetSharding(mlir::PatternRewriter &rewriter, + mlir::func::FuncOp funcOp, uint64_t retIdx, + mlir::sdy::TensorShardingAttr shardingAttr); // Getter functions. mlir::tt::MeshShardDirection getShardDirection() const { @@ -103,14 +87,14 @@ class MeshSharding { meshShape = llvm::SmallVector{-1}; } - // Force dummy sharding op by setting shard_type to manual. The mesh_shard op - // will be ignored at runtime by simply copying input tensor to output. - void setDummyShardingOp() { shardType = mlir::tt::MeshShardType::Manual; } + // Force dummy sharding op by setting shard_type to identity. The mesh_shard + // op will be ignored at runtime by simply copying input tensor to output. + void setDummyShardingOp() { shardType = mlir::tt::MeshShardType::Identity; } private: mlir::tt::MeshShardDirection shardDirection = mlir::tt::MeshShardDirection::ShardToFull; - mlir::tt::MeshShardType shardType = mlir::tt::MeshShardType::Manual; + mlir::tt::MeshShardType shardType = mlir::tt::MeshShardType::Identity; llvm::SmallVector shardShape{-1}; llvm::SmallVector shardDims{-1}; llvm::SmallVector meshShape{-1}; @@ -118,6 +102,47 @@ class MeshSharding { bool lastTileDimReplicate = false; }; +// Remove arg sharding and return true if it is found, otherwise return false. +template +bool checkAndRemoveFuncArgSharding(mlir::PatternRewriter &rewriter, + mlir::func::FuncOp funcOp, uint64_t argNum, + AttrType shardingAttr, + llvm::StringRef argShardingStrRef) { + if (auto argShardingAttr = + funcOp.getArgAttrOfType(argNum, argShardingStrRef)) { + if (argShardingAttr == shardingAttr) { + rewriter.modifyOpInPlace( + funcOp, [&]() { funcOp.removeArgAttr(argNum, argShardingStrRef); }); + return true; + } + llvm_unreachable("MeshSharding operation and function argument shardings " + "are different."); + } + return false; +} + +// Remove ret sharding and return true if it is found, otherwise return false. 
+template +bool checkAndRemoveFuncReturnSharding(mlir::PatternRewriter &rewriter, + mlir::func::FuncOp funcOp, + uint64_t retIdx, AttrType shardingAttr, + llvm::StringRef retShardingStrRef) { + if (auto retShardingAttr = + funcOp.getResultAttrOfType(retIdx, retShardingStrRef)) { + if (retShardingAttr == shardingAttr) { + rewriter.modifyOpInPlace(funcOp, [&]() { + funcOp.removeResultAttr( + retIdx, + mlir::StringAttr::get(rewriter.getContext(), retShardingStrRef)); + }); + return true; + } + llvm_unreachable("MeshSharding operation and function return shardings " + "are different."); + } + return false; +} + // Sharding related string definitions from open-xla // https://github.com/openxla/xla/blob/main/xla/service/spmd/shardy/constants.h diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td b/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td index 44d096ae87..12fbb86c28 100644 --- a/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td +++ b/include/ttmlir/Dialect/TT/IR/TTOpsEnums.td @@ -164,14 +164,14 @@ def TT_MeshShardDirection: I32EnumAttr<"MeshShardDirection", "TT MeshShardDirect let cppNamespace = "::mlir::tt"; } -def TT_MeshShardType_Manual : I32EnumAttrCase<"Manual", 0, "manual">; +def TT_MeshShardType_Identity : I32EnumAttrCase<"Identity", 0, "identity">; def TT_MeshShardType_Replicate : I32EnumAttrCase<"Replicate", 1, "replicate">; def TT_MeshShardType_Maximal : I32EnumAttrCase<"Maximal", 2, "maximal">; def TT_MeshShardType_Devices : I32EnumAttrCase<"Devices", 3, "devices">; def TT_MeshShardType: I32EnumAttr<"MeshShardType", "TT MeshShardType", [ - TT_MeshShardType_Manual, + TT_MeshShardType_Identity, TT_MeshShardType_Replicate, TT_MeshShardType_Maximal, TT_MeshShardType_Devices, diff --git a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td index d823d2499d..c4cf5fedf6 100644 --- a/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td +++ b/include/ttmlir/Dialect/TT/IR/TTOpsTypes.td @@ -464,6 +464,14 @@ def TT_MeshShardDirectionAttr : EnumAttr 
{ + let summary = "MeshShard shard_type attribute in TT dialect"; + let description = [{ + Define sharded tensor data of mesh_shard op. + - Identity: input and output tensors are pre-sharded (same data) and no sharding is required. + - Replicate: all of the devices has full tensor (same data). + - Maximal: one or part of the devices has full tensor (same data). + - Devices: all or part of the devices has sharded (partial) tensor (different data). + }]; let assemblyFormat = "`<` $value `>`"; } diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td index 381a3f2811..5d5699bc1c 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNOps.td @@ -1503,7 +1503,8 @@ def TTNN_MeshShardOp: TTNN_Op<"mesh_shard"> { let extraClassDeclaration = [{ wa::TTNNOperandsWorkarounds getOperandsWorkarounds() { - return wa::TTNNOperandsWorkaroundsFactory::createMeshShardOpOperandsWorkarounds(); + mlir::tt::ttnn::MeshShardOp op = mlir::cast(this->getOperation()); + return wa::TTNNOperandsWorkaroundsFactory::createMeshShardOpOperandsWorkarounds(op.getShardType()); } }]; diff --git a/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h index 3e354f2bfd..4573f84615 100644 --- a/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h +++ b/include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h @@ -235,7 +235,8 @@ class TTNNOperandsWorkaroundsFactory { createFullOpOperandsWorkarounds(RankedTensorType outputType); // Create workarounds for mesh shard op operands. - static TTNNOperandsWorkarounds createMeshShardOpOperandsWorkarounds(); + static TTNNOperandsWorkarounds + createMeshShardOpOperandsWorkarounds(mlir::tt::MeshShardType shardType); // Create workarounds for concat op operands.
static TTNNOperandsWorkarounds diff --git a/include/ttmlir/Target/TTNN/types.fbs b/include/ttmlir/Target/TTNN/types.fbs index 849586c751..96d34fb46b 100644 --- a/include/ttmlir/Target/TTNN/types.fbs +++ b/include/ttmlir/Target/TTNN/types.fbs @@ -24,7 +24,7 @@ enum MeshShardDirection: uint32 { } enum MeshShardType: uint32 { - Manual, + Identity, Replicate, Maximal, Devices, diff --git a/lib/Conversion/StableHLOToTTIR/ShardingUtils.cpp b/lib/Conversion/StableHLOToTTIR/ShardingUtils.cpp index 90b5273ef5..2037ba485e 100644 --- a/lib/Conversion/StableHLOToTTIR/ShardingUtils.cpp +++ b/lib/Conversion/StableHLOToTTIR/ShardingUtils.cpp @@ -133,7 +133,7 @@ llvm::Expected MeshSharding::determineGSPMDShardingDims() { // https://github.com/sdasgup3/stablehlo/blob/80082431d1af0933e6202ecc8a6f8801e039235b/docs/spec.md#sharding-attribute llvm::Expected MeshSharding::convertGSPMDShardingToMeshSharding(StringRef shardingStr) { - shardType = mlir::tt::MeshShardType::Manual; + shardType = mlir::tt::MeshShardType::Identity; lastTileDimReplicate = false; // Parse string and tokenize. 
@@ -148,21 +148,21 @@ MeshSharding::convertGSPMDShardingToMeshSharding(StringRef shardingStr) { for (auto str : shardingStrTokens) { if (str.contains("manual")) { // manual: already sharded, so no action is needed - if (shardType != tt::MeshShardType::Manual) { + if (shardType != tt::MeshShardType::Identity) { return llvm::createStringError(std::errc::invalid_argument, "Fail to parse GSPMD sharding."); } - setNonDevicesShardType(tt::MeshShardType::Manual); + setNonDevicesShardType(tt::MeshShardType::Identity); } else if (str.contains("replicated")) { // replicated: all devices have whole data - if (shardType != tt::MeshShardType::Manual) { + if (shardType != tt::MeshShardType::Identity) { return llvm::createStringError(std::errc::invalid_argument, "Fail to parse GSPMD sharding."); } setNonDevicesShardType(tt::MeshShardType::Replicate); } else if (str.contains("maximal")) { // maximal: one device has whole data - if (shardType != tt::MeshShardType::Manual) { + if (shardType != tt::MeshShardType::Identity) { return llvm::createStringError(std::errc::invalid_argument, "Fail to parse GSPMD sharding."); } @@ -181,7 +181,7 @@ MeshSharding::convertGSPMDShardingToMeshSharding(StringRef shardingStr) { deviceIds.push_back(d); } else if (str.consume_front("devices=")) { // other: "devices" detail sharding plan - if (shardType != tt::MeshShardType::Manual) { + if (shardType != tt::MeshShardType::Identity) { return llvm::createStringError(std::errc::invalid_argument, "Fail to parse GSPMD sharding."); } @@ -214,6 +214,81 @@ MeshSharding::convertGSPMDShardingToMeshSharding(StringRef shardingStr) { return true; } +// Check and remove arg sharding attribute and determine if mesh_shard op needs +// to be created or not. 
+bool MeshSharding::checkAndUpdateGSPMDArgSharding( + mlir::PatternRewriter &rewriter, mlir::stablehlo::CustomCallOp srcOp, + mlir::StringAttr shardingAttr) { + auto funcOp = srcOp->getParentOfType(); + bool foundArgSharding = false; + + if (auto blockArg = + mlir::dyn_cast(srcOp->getOperand(0))) { + auto argNum = blockArg.getArgNumber(); + foundArgSharding = checkAndRemoveFuncArgSharding( + rewriter, funcOp, argNum, shardingAttr, + mlir::tt::sharding_utils::kXlaShardingAttr); + } + + // If JAX expects pre-sharded input (foundArgSharding) and if it is replicate, + // do not create mesh_shard op as the input/output shapes are identical. + if (foundArgSharding && shardType == mlir::tt::MeshShardType::Replicate) { + return false; + } + + // If JAX expects pre-sharded input (foundArgSharding) and if it is not + // replicate, mesh_shard op with dummy operation is still necessary for + // input/output shape conversion purpose. + if (foundArgSharding) { + setDummyShardingOp(); + } + + return true; +} + +// Check and remove ret sharding attribute and determine if mesh_shard op needs +// to be created or not. +bool MeshSharding::checkAndUpdateGSPMDRetSharding( + mlir::PatternRewriter &rewriter, mlir::stablehlo::CustomCallOp srcOp, + mlir::StringAttr shardingAttr) { + auto funcOp = srcOp->getParentOfType(); + bool foundRetSharding = false; + + // Check if the GSPMD ShardToFull output is one of the return values. 
+ if (auto *funcReturnOp = funcOp.getBody().front().getTerminator()) { + auto returnOperands = funcReturnOp->getOperands(); + auto returnOperandIt = llvm::find_if(returnOperands, [&](Value operand) { + return operand == srcOp->getResult(0); + }); + if (returnOperandIt != returnOperands.end()) { + auto retIdx = std::distance(returnOperands.begin(), returnOperandIt); + foundRetSharding = checkAndRemoveFuncReturnSharding( + rewriter, funcOp, retIdx, shardingAttr, + mlir::tt::sharding_utils::kXlaShardingAttr); + } + } + + // If JAX expects sharded output (foundRetSharding) and if it is replicate, do + // not create mesh_shard op. + + // TODO (wooseoklee) : Due to + // https://github.com/llvm/llvm-project/issues/122695, temporarily allow + // mesh_shard op even with replicate shard type. Once the issue is resolved, + // we need to enable this to remove the return sharding with replicate. + // if (foundRetSharding && shardType == mlir::tt::MeshShardType::Replicate) { + // return false; + // } + + // If JAX expects sharded output (foundRetSharding) and if it is not + // replicate, mesh_shard op with dummy operation is still necessary for + // input/output shape conversion purpose. + if (foundRetSharding) { + setDummyShardingOp(); + } + + return true; +} + // Convert sdy.sharding to meshSharding based on sdy::MeshAttr. llvm::Expected MeshSharding::convertSdyShardingToMeshSharding( sdy::TensorShardingAttr sdySharding, sdy::MeshAttr meshAttr, @@ -278,6 +353,62 @@ llvm::Expected MeshSharding::convertSdyShardingToMeshSharding( return true; } +// Check and remove arg sharding attribute and determine if mesh_shard op needs +// to be created or not. 
+bool MeshSharding::checkAndUpdateShardyArgSharding( + mlir::PatternRewriter &rewriter, mlir::func::FuncOp funcOp, + mlir::Value argOperand, mlir::sdy::TensorShardingAttr shardingAttr) { + + bool foundArgSharding = false; + if (auto blockArg = mlir::dyn_cast(argOperand)) { + auto argNum = blockArg.getArgNumber(); + foundArgSharding = + checkAndRemoveFuncArgSharding( + rewriter, funcOp, argNum, shardingAttr, mlir::sdy::kShardingAttr); + } + + // If JAX expects pre-sharded input (foundArgSharding) and if it is replicate, + // do not create mesh_shard op as the input/output shapes are identical. + if (foundArgSharding && shardType == mlir::tt::MeshShardType::Replicate) { + return false; + } + + // If JAX expects pre-sharded input (foundArgSharding) and if it is not + // replicate, mesh_shard op with dummy operation is still necessary for + // input/output shape conversion purpose. + if (foundArgSharding) { + setDummyShardingOp(); + } + + return true; +} + +// Check and remove ret sharding attribute and determine if mesh_shard op needs +// to be created or not. +bool MeshSharding::checkAndUpdateShardyRetSharding( + mlir::PatternRewriter &rewriter, mlir::func::FuncOp funcOp, uint64_t retIdx, + sdy::TensorShardingAttr shardingAttr) { + + bool foundRetSharding = + checkAndRemoveFuncReturnSharding( + rewriter, funcOp, retIdx, shardingAttr, mlir::sdy::kShardingAttr); + + // If JAX expects sharded output (foundRetSharding) and if it is replicate, do + // not create mesh_shard op. + if (foundRetSharding && shardType == mlir::tt::MeshShardType::Replicate) { + return false; + } + + // If JAX expects sharded output (foundRetSharding) and if it is not + // replicate, mesh_shard op with dummy operation is still necessary for + // input/output shape conversion purpose. 
+ if (foundRetSharding) { + setDummyShardingOp(); + } + + return true; +} + } // namespace sharding_utils } // namespace tt } // namespace mlir diff --git a/lib/Conversion/StableHLOToTTIR/ShardyToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/ShardyToTTIRPatterns.cpp index 18574727b2..ef99ae4f2b 100644 --- a/lib/Conversion/StableHLOToTTIR/ShardyToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/ShardyToTTIRPatterns.cpp @@ -20,6 +20,7 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Operation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "shardy/dialect/sdy/ir/constants.h" @@ -106,32 +107,36 @@ class ShardyToTTIRManualComputationOpConversionPattern } // JAX automatic sharding pre-shards input tensors and provides multiple - // buffers. Thus, mesh sharding operations should not shard the tensors - // twice if they are function arguments and pre-sharded by frontend. - if (auto blockArg = mlir::dyn_cast(globalOperand)) { - auto argNum = blockArg.getArgNumber(); - meshSharding - .checkAndUpdateFuncArgSharding( - rewriter, funcOp, argNum, argSharding, - mlir::sdy::kShardingAttr); - } - - auto outputType = mlir::cast( - getTypeConverter()->convertType(localArgType)); + // buffers. Thus, we have to check if mesh shard op is sharding the + // tensors twice. We create dummy mesh shard op if input and output + // shapes are different or not create mesh shard op if they are + // identical. 
+ bool shouldCreateMeshShardOp = + meshSharding.checkAndUpdateShardyArgSharding( + rewriter, funcOp, globalOperand, argSharding); + if (shouldCreateMeshShardOp) { + auto outputType = mlir::cast( + getTypeConverter()->convertType(localArgType)); - auto meshShardOp = - ttmlir::utils::createDPSOp( - rewriter, loc, outputType, globalOperand, - meshSharding.getShardType(), meshSharding.getShardDirection(), - meshSharding.getShardShape(), meshSharding.getShardDims()); + auto meshShardOp = + ttmlir::utils::createDPSOp( + rewriter, loc, outputType, globalOperand, + meshSharding.getShardType(), meshSharding.getShardDirection(), + meshSharding.getShardShape(), meshSharding.getShardDims()); - fullToShardResults.push_back(meshShardOp.getResult()); + fullToShardResults.push_back(meshShardOp.getResult()); + } else { + // Do not create mesh shard op if input and output shapes are + // identical: frontend provides sharded input and shard type is + // replicate. + fullToShardResults.push_back(globalOperand); + } } // Add mesh_shard (ShardToFullShape) for outputs. rewriter.setInsertionPointAfter(srcOp); mlir::Operation *sdyReturn = getBodyTerminator(srcOp); - for (auto [retNum, args] : llvm::enumerate(llvm::zip_equal( + for (auto [retIdx, args] : llvm::enumerate(llvm::zip_equal( sdyReturn->getOpOperands(), srcOp.getOutShardings().getShardings(), srcOp.getResults()))) { auto [returnOperand, outSharding, opResult] = args; @@ -142,30 +147,36 @@ class ShardyToTTIRManualComputationOpConversionPattern return rewriter.notifyMatchFailure(srcOp, llvm::toString(std::move(e))); } - // JAX automatic sharding may expect pre-sharded output tensors. Thus, - // mesh sharding operations should not concat the tensors twice if - // frontent expects pre-sharded tensor. 
- meshSharding - .checkAndUpdateFuncReturnSharding( - rewriter, funcOp, retNum, outSharding, mlir::sdy::kShardingAttr); - - auto inputOperand = returnOperand.get(); - auto inputType = mlir::cast( - getTypeConverter()->convertType(inputOperand.getType())); - if (inputType != inputOperand.getType()) { - inputOperand.setType(inputType); - } + // JAX automatic sharding may expect pre-sharded output tensors. We should + // check and update mesh shard op to match frontend's expectation. We may + // create dummy mesh shard op even though frontend expect sharded return + // in case input and output shapes of mesh shard op are different. + bool shouldCreateMeshShardOp = + meshSharding.checkAndUpdateShardyRetSharding(rewriter, funcOp, retIdx, + outSharding); + if (shouldCreateMeshShardOp) { + auto inputOperand = returnOperand.get(); + auto inputType = mlir::cast( + getTypeConverter()->convertType(inputOperand.getType())); + if (inputType != inputOperand.getType()) { + inputOperand.setType(inputType); + } - auto outputType = mlir::cast( - getTypeConverter()->convertType(opResult.getType())); + auto outputType = mlir::cast( + getTypeConverter()->convertType(opResult.getType())); - auto meshShardOp = - ttmlir::utils::createDPSOp( - rewriter, loc, outputType, inputOperand, - meshSharding.getShardType(), meshSharding.getShardDirection(), - meshSharding.getShardShape(), meshSharding.getShardDims()); + auto meshShardOp = + ttmlir::utils::createDPSOp( + rewriter, loc, outputType, inputOperand, + meshSharding.getShardType(), meshSharding.getShardDirection(), + meshSharding.getShardShape(), meshSharding.getShardDims()); - rewriter.replaceAllUsesWith(opResult, meshShardOp.getResult()); + rewriter.replaceAllUsesWith(opResult, meshShardOp.getResult()); + } else { + // Do not create mesh shard op if input and output shapes are identical: + // frontend expects sharded return and shard type is replicate. 
+ rewriter.replaceAllUsesWith(opResult, returnOperand.get()); + } } // Inline inner block ops. @@ -186,8 +197,8 @@ class ShardyToTTIRMeshOpConversionPattern llvm::LogicalResult matchAndRewrite(mlir::sdy::MeshOp srcOp, mlir::sdy::MeshOp::Adaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const override { - // The goal of this conversion is to extract hardware mesh information from - // sdy.mesh op and store it as module attribute. + // The main goal of this conversion is to extract hardware mesh information + // from sdy.mesh op and store it as module attribute. auto module = srcOp->getParentOfType(); if (!module) { llvm_unreachable( @@ -202,7 +213,36 @@ class ShardyToTTIRMeshOpConversionPattern } mlir::tt::utils::addMeshToModuleAttribute(rewriter, module, meshName, meshShape); + + // Before erasing MeshOp, visit public functions and erase argument sharding + // attributes that are not referred by ManualComputationOp. Ones that are + // referred by ManualComputationOp are properly handled by + // ShardyToTTIRManualComputationOpConversionPattern.
+ module->walk([&](mlir::func::FuncOp funcOp) { + if (funcOp.isPublic()) { + for (auto arg : funcOp.getArguments()) { + auto argIdx = arg.getArgNumber(); + auto argShardingAttr = + funcOp.getArgAttrOfType( + argIdx, mlir::sdy::kShardingAttr); + if (!argShardingAttr) { + continue; + } + if (llvm::any_of(arg.getUsers(), [&](mlir::Operation *user) { + return mlir::dyn_cast_if_present< + mlir::sdy::ManualComputationOp>(*user); + })) { + continue; + } + rewriter.modifyOpInPlace(funcOp, [&]() { + funcOp.removeArgAttr(argIdx, mlir::sdy::kShardingAttr); + }); + } + } + }); + rewriter.eraseOp(srcOp); + return llvm::success(); } }; diff --git a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp index de1c9b1ba9..a79f97d08e 100644 --- a/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp +++ b/lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp @@ -1798,7 +1798,6 @@ class StableHLOToTTIRCustomCallOpConversionPattern meshShape); } - auto funcOp = srcOp->getParentOfType(); if (callTargetName == mlir::tt::sharding_utils::kSPMDFullToShardShapeCallTargetName) { // @Sharding => @SPMDFullToShardShape pattern @@ -1817,37 +1816,33 @@ class StableHLOToTTIRCustomCallOpConversionPattern srcOp, "Requires operand to be defined by prior Sharding op."); } - // JAX automatic sharding may expect pre-sharded output tensors. Thus, - // mesh sharding operations should not concat the tensors twice if - // frontent expects pre-sharded tensor. 
- if (auto *funcReturnOp = funcOp.getBody().front().getTerminator()) { - auto returnOperands = funcReturnOp->getOperands(); - auto returnOperandIt = - llvm::find_if(returnOperands, [&](Value operand) { - return operand == srcOp->getResult(0); - }); - if (returnOperandIt != returnOperands.end()) { - auto retNum = std::distance(returnOperands.begin(), returnOperandIt); - meshSharding.checkAndUpdateFuncReturnSharding( - rewriter, funcOp, retNum, shardingAttr, - mlir::tt::sharding_utils::kXlaShardingAttr); - } - } - - auto outputType = mlir::cast( - getTypeConverter()->convertType(srcOp->getResult(0).getType())); - - ttmlir::utils::replaceOpWithNewDPSOp( - rewriter, srcOp, outputType, adaptor.getInputs().front(), - meshSharding.getShardType(), - mlir::tt::MeshShardDirection::ShardToFull, - meshSharding.getShardShape(), meshSharding.getShardDims()); + // JAX automatic sharding may expect pre-sharded output tensors. We should + // check and update mesh shard op to match frontend's expectation. We may + // create dummy mesh shard op even though frontend expect sharded return + // in case input and output shapes of mesh shard op are different. + bool shouldCreateMeshShardOp = + meshSharding.checkAndUpdateGSPMDRetSharding(rewriter, srcOp, + shardingAttr); + auto inputOperand = adaptor.getInputs().front(); + if (shouldCreateMeshShardOp) { + auto outputType = mlir::cast( + getTypeConverter()->convertType(srcOp->getResult(0).getType())); + ttmlir::utils::replaceOpWithNewDPSOp( + rewriter, srcOp, outputType, inputOperand, + meshSharding.getShardType(), + mlir::tt::MeshShardDirection::ShardToFull, + meshSharding.getShardShape(), meshSharding.getShardDims()); + } else { + // Do not create mesh shard op if input and output shapes are identical: + // frontend expects sharded return and shard type is replicate. 
+ rewriter.replaceOp(srcOp, inputOperand); + } } else if (callTargetName == mlir::tt::sharding_utils::kShardingCustomCallTargetName) { - if (meshSharding.getShardType() == mlir::tt::MeshShardType::Manual) { + if (meshSharding.getShardType() == mlir::tt::MeshShardType::Identity) { // @Sharding => @SPMDShardToFullShape pattern - // "manual" sharding indicates no sharding is required. + // "identity" sharding indicates no sharding is required. rewriter.replaceOp(srcOp, srcOp->getOperand(0)); } else { // @Sharding => @SPMDFullToShardShape pattern @@ -1859,25 +1854,30 @@ class StableHLOToTTIRCustomCallOpConversionPattern } // JAX automatic sharding pre-shards input tensors and provides multiple - // buffers. Thus, mesh sharding operations should not shard the tensors - // twice if they are function arguments and pre-sharded by frontend. + // buffers. Thus, we have to check if mesh shard op is sharding the + // tensors twice. We create dummy mesh shard op if input and output + // shapes are different or not create mesh shard op if they are + // identical. + bool shouldCreateMeshShardOp = + meshSharding.checkAndUpdateGSPMDArgSharding(rewriter, srcOp, + shardingAttr); auto inputOperand = adaptor.getInputs().front(); - if (auto blockArg = mlir::dyn_cast(inputOperand)) { - auto argNum = blockArg.getArgNumber(); - meshSharding.checkAndUpdateFuncArgSharding( - rewriter, funcOp, argNum, shardingAttr, - mlir::tt::sharding_utils::kXlaShardingAttr); + if (shouldCreateMeshShardOp) { + auto outputType = + mlir::cast(getTypeConverter()->convertType( + fullToShardCustomCall->getResult(0).getType())); + + ttmlir::utils::replaceOpWithNewDPSOp( + rewriter, srcOp, outputType, inputOperand, + meshSharding.getShardType(), + mlir::tt::MeshShardDirection::FullToShard, + meshSharding.getShardShape(), meshSharding.getShardDims()); + } else { + // Do not create mesh shard op if input and output shapes are + // identical: frontend provides sharded input and shard type is + // replicate.
+ rewriter.replaceOp(srcOp, inputOperand); } - - auto outputType = - mlir::cast(getTypeConverter()->convertType( - fullToShardCustomCall->getResult(0).getType())); - - ttmlir::utils::replaceOpWithNewDPSOp( - rewriter, srcOp, outputType, inputOperand, - meshSharding.getShardType(), - mlir::tt::MeshShardDirection::FullToShard, - meshSharding.getShardShape(), meshSharding.getShardDims()); } } return success(); diff --git a/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp b/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp index 6793a8f5e5..38f66bcc9c 100644 --- a/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp +++ b/lib/Dialect/TTNN/IR/TTNNWorkarounds.cpp @@ -220,9 +220,12 @@ TTNNOperandsWorkaroundsFactory::createFullOpOperandsWorkarounds( // Factory method to create a set of workarounds for mesh shard op input // operand. ttnn::MeshShardOp supports host tensors only TTNNOperandsWorkarounds -TTNNOperandsWorkaroundsFactory::createMeshShardOpOperandsWorkarounds() { +TTNNOperandsWorkaroundsFactory::createMeshShardOpOperandsWorkarounds( + mlir::tt::MeshShardType shardType) { wa::TTNNOperandWorkarounds sysMemWorkaround; - sysMemWorkaround.tensorBufferTypeWorkaround = BufferType::SystemMemory; + if (shardType != mlir::tt::MeshShardType::Identity) { + sysMemWorkaround.tensorBufferTypeWorkaround = BufferType::SystemMemory; + } return wa::TTNNOperandsWorkarounds::createEmptyTTNNOperandsWorkarounds() .addInputOperandWorkaround(sysMemWorkaround) .addOutputOperandWorkaround(sysMemWorkaround); diff --git a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp index e690b9696b..b23f7e3815 100644 --- a/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp +++ b/lib/Dialect/TTNN/Transforms/TTNNLayout.cpp @@ -80,6 +80,15 @@ createLayoutAttr(MLIRContext *ctx, GridAttr deviceGrid, RankedTensorType type, tensorGrid, memoryLayoutAttr, collapseDimsRef); } +inline bool shouldMeshShardOpForceSystemMemory(mlir::Operation *srcOp) { + auto meshShardOp = mlir::dyn_cast_if_present(srcOp); + if 
(meshShardOp && + meshShardOp.getShardType() != mlir::tt::MeshShardType::Identity) { + return true; + } + return false; +} + //===----------------------------------------------------------------------===// // To layout pass //===----------------------------------------------------------------------===// @@ -380,7 +389,7 @@ class TTNNLayoutDPSOperandsRewriter // handle canonicalization of toLayout ops (#2102). Currently the // workaround pass cannot detect redundant toLayout ops as a result of // forcing the output layout and removing them. - if (mlir::isa(op.getOperation())) { + if (shouldMeshShardOpForceSystemMemory(op.getOperation())) { modified = changeLayoutToHost(op, operand, rewriter, isDPSResult); continue; } @@ -600,7 +609,7 @@ class TTNNLayoutFuncInputOutputTypeRewriter bool shouldForceInputSystemMemory(BlockArgument arg) const { for (Operation *user : arg.getUsers()) { - if (mlir::isa(user)) { + if (shouldMeshShardOpForceSystemMemory(user)) { return true; } // For the weight input of the conv2d op, it specifically needs to be on @@ -619,7 +628,7 @@ class TTNNLayoutFuncInputOutputTypeRewriter if (!mlir::isa(operand.getType())) { continue; } - if (operand.getDefiningOp()) { + if (shouldMeshShardOpForceSystemMemory(operand.getDefiningOp())) { return true; } } @@ -668,7 +677,7 @@ class TTNNLayoutFuncReturnRewriter private: // Return op output should be on host if it's a result of a mesh shard op bool shouldForceSystemMemory(Value operandValue) const { - if (operandValue.getDefiningOp()) { + if (shouldMeshShardOpForceSystemMemory(operandValue.getDefiningOp())) { return true; } return false; diff --git a/lib/Target/TTNN/TTNNToFlatbuffer.cpp b/lib/Target/TTNN/TTNNToFlatbuffer.cpp index b4e95518ba..874ee9888e 100644 --- a/lib/Target/TTNN/TTNNToFlatbuffer.cpp +++ b/lib/Target/TTNN/TTNNToFlatbuffer.cpp @@ -824,8 +824,8 @@ createOp(FlatbufferObjectCache &cache, MeshShardOp op) { meshShardType = ::tt::target::ttnn::MeshShardType::Replicate; } else if (shardType == 
mlir::tt::MeshShardType::Devices) { meshShardType = ::tt::target::ttnn::MeshShardType::Devices; - } else if (shardType == mlir::tt::MeshShardType::Manual) { - meshShardType = ::tt::target::ttnn::MeshShardType::Manual; + } else if (shardType == mlir::tt::MeshShardType::Identity) { + meshShardType = ::tt::target::ttnn::MeshShardType::Identity; } else { llvm_unreachable("unhandled mesh_shard type"); } diff --git a/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp b/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp index 585ea2530a..41c3f6cdc5 100644 --- a/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp +++ b/runtime/lib/ttnn/operations/ccl/mesh_shard.cpp @@ -97,15 +97,15 @@ void run(const ::tt::target::ttnn::MeshShardOp *op, ProgramContext &context) { DEBUG_ASSERT(::tt::runtime::ttnn::utils::isOnHost(input.storage_type()), "Input of ttnn::mesh_shard should be host tensor"); - // Regards manual sharding as no op assuming that the input tensor is + // Regards identity mesh shard type as no op assuming that the input tensor is // pre-sharded by frontend. Thus, no sharding is required, but need to makes // sure if the tensor is multi-device host tensor. - if (shardType == ::tt::target::ttnn::MeshShardType::Manual) { - LOG_ASSERT(input.storage_type() == - ::tt::tt_metal::StorageType::MULTI_DEVICE_HOST, - "Input of mesh_shard with manual sharding must be MULTI DEVICE " - "HOST Storage. id:", - op->in()->global_id()); + if (shardType == ::tt::target::ttnn::MeshShardType::Identity) { + LOG_ASSERT( + input.storage_type() == ::tt::tt_metal::StorageType::MULTI_DEVICE_HOST, + "Input of mesh_shard with identity shard_type must be MULTI DEVICE " + "HOST Storage. 
id:", + op->in()->global_id()); tensorPool.insertAndValidate(op->out(), input); return; } diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_gspmd.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_gspmd.mlir index 6abccf1c93..96ca6964ef 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_gspmd.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_gspmd.mlir @@ -1826,7 +1826,7 @@ module @jit_negative_basic attributes {mhlo.num_partitions = 2 : i32, mhlo.num_r // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array - // CHECK-SAME: shard_type = #tt.shard_type + // CHECK-SAME: shard_type = #tt.shard_type %2 = call @shmap_body(%1) : (tensor<256x128xf32>) -> tensor<256x128xf32> %3 = stablehlo.custom_call @Sharding(%2) {mhlo.sharding = "{manual}"} : (tensor<256x128xf32>) -> tensor<256x128xf32> %4 = stablehlo.custom_call @SPMDShardToFullShape(%3) {mhlo.sharding = "{replicated}"} : (tensor<256x128xf32>) -> tensor<256x128xf32> @@ -1834,7 +1834,7 @@ module @jit_negative_basic attributes {mhlo.num_partitions = 2 : i32, mhlo.num_r // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array - // CHECK-SAME: shard_type = #tt.shard_type + // CHECK-SAME: shard_type = #tt.shard_type return %4 : tensor<256x128xf32> } func.func private @shmap_body(%arg0: tensor<256x128xf32>) -> (tensor<256x128xf32> {jax.result_info = "[None, None]"}) { diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_shardy.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_shardy.mlir index dd9b8f8539..67262d7ab6 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_shardy.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/ccl_ops_shardy.mlir @@ -318,12 +318,12 @@ module @jit_matmul_shardy_automatic attributes {mhlo.num_partitions = 8 : i32, m // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = 
#tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: = "ttir.all_reduce" // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array @@ -353,15 +353,15 @@ module @jit_matmul_shardy1 attributes {mhlo.num_partitions = 8 : i32, mhlo.num_r // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: = "ttir.all_reduce" // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_gspmd.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_gspmd.mlir index 2acb787afe..e38502ee99 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_gspmd.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_gspmd.mlir @@ -116,75 +116,15 @@ module @jit_loss_dp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas // CHECK-LABEL @main // CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: 
shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = 
#tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_shardy.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_shardy.mlir index facbd69166..7e54cbe878 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_shardy.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_dp_shardy.mlir @@ -87,75 +87,15 @@ module @jit_loss_dp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas // CHECK-LABEL @main // CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: 
shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" -// CHECK-SAME: shard_dims = array -// CHECK-SAME: shard_direction = #tt.shard_direction -// CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type -// CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: 
"ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_gspmd.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_gspmd.mlir index 94071c41b6..167de4756c 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_gspmd.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_gspmd.mlir @@ -132,72 +132,72 @@ module @jit_loss_fsdp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replic // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: 
shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction diff --git 
a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_shardy.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_shardy.mlir index 8bbe2417ee..7099031924 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_shardy.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_shardy.mlir @@ -102,72 +102,72 @@ module @jit_loss_fsdp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replic // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = 
#tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction // CHECK-SAME: shard_shape = array -// CHECK-SAME: shard_type = #tt.shard_type +// CHECK-SAME: shard_type = #tt.shard_type // CHECK: "ttir.mesh_shard" // CHECK-SAME: shard_dims = array // CHECK-SAME: shard_direction = #tt.shard_direction diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_gspmd.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_gspmd.mlir index 
f696f4b268..b15c864784 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_gspmd.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_gspmd.mlir @@ -164,3 +164,78 @@ module @jit_loss_fsdp_tp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_rep } // CHECK-LABEL @main +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = 
#tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_shardy.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_shardy.mlir index 5ddd3fb180..a5d798b545 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_shardy.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_fsdp_tp_shardy.mlir @@ -1,6 +1,5 @@ // REQUIRES: stablehlo // RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s -// UNSUPPORTED: true module @jit_loss_fsdp_tp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { sdy.mesh @mesh = <["x"=2, "y"=4]> @@ -134,3 +133,78 @@ module @jit_loss_fsdp_tp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_rep 
} // CHECK-LABEL @main +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = 
#tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_pipeline_gspmd.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_pipeline_gspmd.mlir new file mode 100644 index 0000000000..2391e483eb --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_pipeline_gspmd.mlir @@ -0,0 +1,589 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s +// UNSUPPORTED: true + +module @jit_loss_pp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<784x128xf32> {mhlo.sharding = "{replicated}"}, %arg1: tensor<128xf32> {mhlo.sharding = "{replicated}"}, %arg2: tensor<4x128x128xf32> {mhlo.sharding = "{devices=[2,1,1,4]<=[8] last_tile_dim_replicate}"}, %arg3: tensor<4x128xf32> {mhlo.sharding = "{devices=[2,1,4]<=[8] last_tile_dim_replicate}"}, %arg4: tensor<128x8xf32> {mhlo.sharding = "{replicated}"}, %arg5: tensor<8xf32> {mhlo.sharding = "{replicated}"}, %arg6: 
tensor<32x784xf32> {mhlo.sharding = "{devices=[2,1,4]<=[8] last_tile_dim_replicate}"}, %arg7: tensor<32x8xf32> {mhlo.sharding = "{devices=[2,1,4]<=[8] last_tile_dim_replicate}"}) -> (tensor {jax.result_info = ""}) { + %0 = stablehlo.custom_call @Sharding(%arg0) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<784x128xf32>) -> tensor<784x128xf32> + %1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<784x128xf32>) -> tensor<784x128xf32> + %2 = stablehlo.custom_call @Sharding(%arg1) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<128xf32>) -> tensor<128xf32> + %3 = stablehlo.custom_call @SPMDFullToShardShape(%2) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128xf32>) -> tensor<128xf32> + %4 = stablehlo.custom_call @Sharding(%arg2) {backend_config = "", mhlo.sharding = "{devices=[2,1,1,4]<=[8] last_tile_dim_replicate}"} : (tensor<4x128x128xf32>) -> tensor<4x128x128xf32> + %5 = stablehlo.custom_call @SPMDFullToShardShape(%4) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<4x128x128xf32>) -> tensor<2x128x128xf32> + %6 = stablehlo.custom_call @Sharding(%arg3) {backend_config = "", mhlo.sharding = "{devices=[2,1,4]<=[8] last_tile_dim_replicate}"} : (tensor<4x128xf32>) -> tensor<4x128xf32> + %7 = stablehlo.custom_call @SPMDFullToShardShape(%6) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<4x128xf32>) -> tensor<2x128xf32> + %8 = stablehlo.custom_call @Sharding(%arg4) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<128x8xf32>) -> tensor<128x8xf32> + %9 = stablehlo.custom_call @SPMDFullToShardShape(%8) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<128x8xf32>) -> tensor<128x8xf32> + %10 = stablehlo.custom_call @Sharding(%arg5) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor<8xf32>) -> tensor<8xf32> + %11 = stablehlo.custom_call @SPMDFullToShardShape(%10) {backend_config = "", mhlo.sharding = "{manual}"} 
: (tensor<8xf32>) -> tensor<8xf32> + %12 = stablehlo.custom_call @Sharding(%arg6) {backend_config = "", mhlo.sharding = "{devices=[2,1,4]<=[8] last_tile_dim_replicate}"} : (tensor<32x784xf32>) -> tensor<32x784xf32> + %13 = stablehlo.custom_call @SPMDFullToShardShape(%12) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x784xf32>) -> tensor<16x784xf32> + %14 = stablehlo.custom_call @Sharding(%arg7) {backend_config = "", mhlo.sharding = "{devices=[2,1,4]<=[8] last_tile_dim_replicate}"} : (tensor<32x8xf32>) -> tensor<32x8xf32> + %15 = stablehlo.custom_call @SPMDFullToShardShape(%14) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor<32x8xf32>) -> tensor<16x8xf32> + %16 = call @shmap_body(%1, %3, %5, %7, %9, %11, %13, %15) : (tensor<784x128xf32>, tensor<128xf32>, tensor<2x128x128xf32>, tensor<2x128xf32>, tensor<128x8xf32>, tensor<8xf32>, tensor<16x784xf32>, tensor<16x8xf32>) -> tensor + %17 = stablehlo.custom_call @Sharding(%16) {backend_config = "", mhlo.sharding = "{manual}"} : (tensor) -> tensor + %18 = stablehlo.custom_call @SPMDShardToFullShape(%17) {backend_config = "", mhlo.sharding = "{replicated}"} : (tensor) -> tensor + return %18 : tensor + } + func.func private @shmap_body(%arg0: tensor<784x128xf32>, %arg1: tensor<128xf32>, %arg2: tensor<2x128x128xf32>, %arg3: tensor<2x128xf32>, %arg4: tensor<128x8xf32>, %arg5: tensor<8xf32>, %arg6: tensor<16x784xf32>, %arg7: tensor<16x8xf32>) -> (tensor {jax.result_info = "[]"}) { + %0 = stablehlo.reshape %arg6 : (tensor<16x784xf32>) -> tensor<2x8x784xf32> + %1 = stablehlo.dot_general %0, %arg0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x8x784xf32>, tensor<784x128xf32>) -> tensor<2x8x128xf32> + %2 = stablehlo.broadcast_in_dim %arg1, dims = [2] : (tensor<128xf32>) -> tensor<1x1x128xf32> + %3 = stablehlo.broadcast_in_dim %2, dims = [0, 1, 2] : (tensor<1x1x128xf32>) -> tensor<2x8x128xf32> + %4 = stablehlo.add %1, %3 : tensor<2x8x128xf32> + %5 = call @relu(%4) : 
(tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %c = stablehlo.constant dense<4> : tensor + %c_0 = stablehlo.constant dense<2> : tensor + %6 = stablehlo.partition_id : tensor + %7 = stablehlo.divide %6, %c : tensor + %8 = stablehlo.remainder %7, %c_0 : tensor + %9 = stablehlo.convert %8 : (tensor) -> tensor + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %10 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %cst_1 = stablehlo.constant dense<0x7FC00000> : tensor + %11 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor) -> tensor<2x8x128xf32> + %12 = stablehlo.multiply %10, %11 : tensor<2x8x128xf32> + %13 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %14 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor) -> tensor<2x8x128xf32> + %15 = stablehlo.multiply %13, %14 : tensor<2x8x128xf32> + %c_2 = stablehlo.constant dense<0> : tensor + %16 = stablehlo.compare EQ, %9, %c_2, SIGNED : (tensor, tensor) -> tensor + %17 = stablehlo.slice %5 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %18 = stablehlo.reshape %17 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %19 = stablehlo.slice %15 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %20 = stablehlo.reshape %19 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %21 = call @_where(%16, %18, %20) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %22 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %23 = "stablehlo.scatter"(%15, %22, %21) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %24 = stablehlo.dot_general %23, %arg2, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, 
tensor<2x128x128xf32>) -> tensor<2x8x128xf32> + %25 = stablehlo.broadcast_in_dim %arg3, dims = [0, 2] : (tensor<2x128xf32>) -> tensor<2x1x128xf32> + %26 = stablehlo.broadcast_in_dim %25, dims = [0, 1, 2] : (tensor<2x1x128xf32>) -> tensor<2x8x128xf32> + %27 = stablehlo.add %24, %26 : tensor<2x8x128xf32> + %28 = call @relu_0(%27) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %c_3 = stablehlo.constant dense<1> : tensor + %29 = stablehlo.compare EQ, %9, %c_3, SIGNED : (tensor, tensor) -> tensor + %c_4 = stablehlo.constant dense<-1> : tensor + %c_5 = stablehlo.constant dense<0> : tensor + %30 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %c_6 = stablehlo.constant dense<-1> : tensor + %c_7 = stablehlo.constant dense<2> : tensor + %31 = stablehlo.add %c_6, %c_7 : tensor + %32 = stablehlo.select %30, %31, %c_4 : tensor, tensor + %33 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %c_8 = stablehlo.constant dense<8> : tensor + %34 = stablehlo.add %c_2, %c_8 : tensor + %35 = stablehlo.select %33, %34, %c_5 : tensor, tensor + %36 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %c_9 = stablehlo.constant dense<128> : tensor + %37 = stablehlo.add %c_2, %c_9 : tensor + %38 = stablehlo.select %36, %37, %c_5 : tensor, tensor + %39 = stablehlo.dynamic_slice %28, %32, %35, %38, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %40 = stablehlo.reshape %39 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %41 = stablehlo.slice %12 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %42 = stablehlo.reshape %41 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %43 = call @_where_1(%29, %40, %42) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %44 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<1xi32> + %45 = "stablehlo.scatter"(%12, %44, %43) <{indices_are_sorted = true, scatter_dimension_numbers = 
#stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %46 = call @_roll_static(%28) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %47 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %48 = stablehlo.add %c_6, %c_7 : tensor + %49 = stablehlo.select %47, %48, %c_4 : tensor, tensor + %50 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %51 = stablehlo.add %c_2, %c_8 : tensor + %52 = stablehlo.select %50, %51, %c_5 : tensor, tensor + %53 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %54 = stablehlo.add %c_2, %c_9 : tensor + %55 = stablehlo.select %53, %54, %c_5 : tensor, tensor + %56 = stablehlo.dynamic_slice %28, %49, %52, %55, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %57 = stablehlo.reshape %56 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %58 = "stablehlo.collective_permute"(%57) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<8x128xf32>) -> tensor<8x128xf32> + %59 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %60 = "stablehlo.scatter"(%46, %59, %58) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %61 = "stablehlo.collective_permute"(%45) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %62 = stablehlo.compare EQ, %9, %c_2, SIGNED : (tensor, tensor) -> tensor + %63 = 
stablehlo.slice %5 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %64 = stablehlo.reshape %63 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %65 = stablehlo.slice %60 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %66 = stablehlo.reshape %65 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %67 = call @_where_2(%62, %64, %66) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %68 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %69 = "stablehlo.scatter"(%60, %68, %67) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %70 = stablehlo.dot_general %69, %arg2, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, tensor<2x128x128xf32>) -> tensor<2x8x128xf32> + %71 = stablehlo.broadcast_in_dim %arg3, dims = [0, 2] : (tensor<2x128xf32>) -> tensor<2x1x128xf32> + %72 = stablehlo.broadcast_in_dim %71, dims = [0, 1, 2] : (tensor<2x1x128xf32>) -> tensor<2x8x128xf32> + %73 = stablehlo.add %70, %72 : tensor<2x8x128xf32> + %74 = call @relu_3(%73) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %75 = stablehlo.compare EQ, %9, %c_3, SIGNED : (tensor, tensor) -> tensor + %76 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %77 = stablehlo.add %c_6, %c_7 : tensor + %78 = stablehlo.select %76, %77, %c_4 : tensor, tensor + %79 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %80 = stablehlo.add %c_2, %c_8 : tensor + %81 = stablehlo.select %79, %80, %c_5 : tensor, tensor + %82 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %83 = stablehlo.add %c_2, %c_9 : tensor + %84 = stablehlo.select %82, %83, %c_5 : tensor, tensor + %85 = stablehlo.dynamic_slice %74, %78, 
%81, %84, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %86 = stablehlo.reshape %85 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %87 = stablehlo.slice %61 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %88 = stablehlo.reshape %87 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %89 = call @_where_4(%75, %86, %88) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %90 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %91 = "stablehlo.scatter"(%61, %90, %89) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %92 = call @_roll_static_5(%74) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %93 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %94 = stablehlo.add %c_6, %c_7 : tensor + %95 = stablehlo.select %93, %94, %c_4 : tensor, tensor + %96 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %97 = stablehlo.add %c_2, %c_8 : tensor + %98 = stablehlo.select %96, %97, %c_5 : tensor, tensor + %99 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %100 = stablehlo.add %c_2, %c_9 : tensor + %101 = stablehlo.select %99, %100, %c_5 : tensor, tensor + %102 = stablehlo.dynamic_slice %74, %95, %98, %101, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %103 = stablehlo.reshape %102 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %104 = "stablehlo.collective_permute"(%103) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<8x128xf32>) -> tensor<8x128xf32> + %105 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> 
tensor<1xi32> + %106 = "stablehlo.scatter"(%92, %105, %104) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %107 = "stablehlo.collective_permute"(%5) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %108 = stablehlo.compare EQ, %9, %c_2, SIGNED : (tensor, tensor) -> tensor + %109 = stablehlo.slice %107 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %110 = stablehlo.reshape %109 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %111 = stablehlo.slice %106 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %112 = stablehlo.reshape %111 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %113 = call @_where_6(%108, %110, %112) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %114 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %115 = "stablehlo.scatter"(%106, %114, %113) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %116 = stablehlo.dot_general %115, %arg2, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, tensor<2x128x128xf32>) -> tensor<2x8x128xf32> + %117 = stablehlo.broadcast_in_dim %arg3, dims = [0, 2] : (tensor<2x128xf32>) -> tensor<2x1x128xf32> + %118 = stablehlo.broadcast_in_dim %117, dims = [0, 1, 2] : (tensor<2x1x128xf32>) -> tensor<2x8x128xf32> + %119 = stablehlo.add %116, %118 : tensor<2x8x128xf32> + %120 = call @relu_7(%119) : 
(tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %121 = stablehlo.compare EQ, %9, %c_3, SIGNED : (tensor, tensor) -> tensor + %122 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %123 = stablehlo.add %c_6, %c_7 : tensor + %124 = stablehlo.select %122, %123, %c_4 : tensor, tensor + %125 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %126 = stablehlo.add %c_2, %c_8 : tensor + %127 = stablehlo.select %125, %126, %c_5 : tensor, tensor + %128 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %129 = stablehlo.add %c_2, %c_9 : tensor + %130 = stablehlo.select %128, %129, %c_5 : tensor, tensor + %131 = stablehlo.dynamic_slice %120, %124, %127, %130, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %132 = stablehlo.reshape %131 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %133 = stablehlo.slice %91 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %134 = stablehlo.reshape %133 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %135 = call @_where_8(%121, %132, %134) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %136 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<1xi32> + %137 = "stablehlo.scatter"(%91, %136, %135) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %138 = call @_roll_static_9(%120) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %139 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %140 = stablehlo.add %c_6, %c_7 : tensor + %141 = stablehlo.select %139, %140, %c_4 : tensor, tensor + %142 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %143 = stablehlo.add %c_2, %c_8 : tensor + %144 = stablehlo.select %142, %143, %c_5 : 
tensor, tensor + %145 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %146 = stablehlo.add %c_2, %c_9 : tensor + %147 = stablehlo.select %145, %146, %c_5 : tensor, tensor + %148 = stablehlo.dynamic_slice %120, %141, %144, %147, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %149 = stablehlo.reshape %148 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %150 = "stablehlo.collective_permute"(%149) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<8x128xf32>) -> tensor<8x128xf32> + %151 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %152 = "stablehlo.scatter"(%138, %151, %150) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %153 = "stablehlo.collective_permute"(%137) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %154 = stablehlo.compare EQ, %9, %c_2, SIGNED : (tensor, tensor) -> tensor + %155 = stablehlo.slice %107 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %156 = stablehlo.reshape %155 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %157 = stablehlo.slice %152 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %158 = stablehlo.reshape %157 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %159 = call @_where_10(%154, %156, %158) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %160 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %161 = "stablehlo.scatter"(%152, %160, %159) <{indices_are_sorted = 
true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %162 = stablehlo.dot_general %161, %arg2, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, tensor<2x128x128xf32>) -> tensor<2x8x128xf32> + %163 = stablehlo.broadcast_in_dim %arg3, dims = [0, 2] : (tensor<2x128xf32>) -> tensor<2x1x128xf32> + %164 = stablehlo.broadcast_in_dim %163, dims = [0, 1, 2] : (tensor<2x1x128xf32>) -> tensor<2x8x128xf32> + %165 = stablehlo.add %162, %164 : tensor<2x8x128xf32> + %166 = call @relu_11(%165) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %167 = stablehlo.compare EQ, %9, %c_3, SIGNED : (tensor, tensor) -> tensor + %168 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %169 = stablehlo.add %c_6, %c_7 : tensor + %170 = stablehlo.select %168, %169, %c_4 : tensor, tensor + %171 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %172 = stablehlo.add %c_2, %c_8 : tensor + %173 = stablehlo.select %171, %172, %c_5 : tensor, tensor + %174 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %175 = stablehlo.add %c_2, %c_9 : tensor + %176 = stablehlo.select %174, %175, %c_5 : tensor, tensor + %177 = stablehlo.dynamic_slice %166, %170, %173, %176, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %178 = stablehlo.reshape %177 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %179 = stablehlo.slice %153 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %180 = stablehlo.reshape %179 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %181 = call @_where_12(%167, %178, %180) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %182 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + 
%183 = "stablehlo.scatter"(%153, %182, %181) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %184 = call @_roll_static_13(%166) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %185 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %186 = stablehlo.add %c_6, %c_7 : tensor + %187 = stablehlo.select %185, %186, %c_4 : tensor, tensor + %188 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %189 = stablehlo.add %c_2, %c_8 : tensor + %190 = stablehlo.select %188, %189, %c_5 : tensor, tensor + %191 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %192 = stablehlo.add %c_2, %c_9 : tensor + %193 = stablehlo.select %191, %192, %c_5 : tensor, tensor + %194 = stablehlo.dynamic_slice %166, %187, %190, %193, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %195 = stablehlo.reshape %194 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %196 = "stablehlo.collective_permute"(%195) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<8x128xf32>) -> tensor<8x128xf32> + %197 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %198 = "stablehlo.scatter"(%184, %197, %196) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %199 = "stablehlo.collective_permute"(%107) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : 
tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %200 = stablehlo.compare EQ, %9, %c_2, SIGNED : (tensor, tensor) -> tensor + %201 = stablehlo.slice %199 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %202 = stablehlo.reshape %201 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %203 = stablehlo.slice %198 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %204 = stablehlo.reshape %203 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %205 = call @_where_14(%200, %202, %204) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %206 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %207 = "stablehlo.scatter"(%198, %206, %205) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %208 = stablehlo.dot_general %207, %arg2, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, tensor<2x128x128xf32>) -> tensor<2x8x128xf32> + %209 = stablehlo.broadcast_in_dim %arg3, dims = [0, 2] : (tensor<2x128xf32>) -> tensor<2x1x128xf32> + %210 = stablehlo.broadcast_in_dim %209, dims = [0, 1, 2] : (tensor<2x1x128xf32>) -> tensor<2x8x128xf32> + %211 = stablehlo.add %208, %210 : tensor<2x8x128xf32> + %212 = call @relu_15(%211) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %213 = stablehlo.compare EQ, %9, %c_3, SIGNED : (tensor, tensor) -> tensor + %214 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %215 = stablehlo.add %c_6, %c_7 : tensor + %216 = stablehlo.select %214, %215, %c_4 : tensor, tensor + %217 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %218 = stablehlo.add %c_2, %c_8 : tensor + %219 = stablehlo.select %217, %218, %c_5 : tensor, tensor + %220 = stablehlo.compare 
LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %221 = stablehlo.add %c_2, %c_9 : tensor + %222 = stablehlo.select %220, %221, %c_5 : tensor, tensor + %223 = stablehlo.dynamic_slice %212, %216, %219, %222, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %224 = stablehlo.reshape %223 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %225 = stablehlo.slice %183 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %226 = stablehlo.reshape %225 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %227 = call @_where_16(%213, %224, %226) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %228 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<1xi32> + %229 = "stablehlo.scatter"(%183, %228, %227) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %230 = call @_roll_static_17(%212) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %231 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %232 = stablehlo.add %c_6, %c_7 : tensor + %233 = stablehlo.select %231, %232, %c_4 : tensor, tensor + %234 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %235 = stablehlo.add %c_2, %c_8 : tensor + %236 = stablehlo.select %234, %235, %c_5 : tensor, tensor + %237 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %238 = stablehlo.add %c_2, %c_9 : tensor + %239 = stablehlo.select %237, %238, %c_5 : tensor, tensor + %240 = stablehlo.dynamic_slice %212, %233, %236, %239, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %241 = stablehlo.reshape %240 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %242 = "stablehlo.collective_permute"(%241) <{channel_handle = 
#stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<8x128xf32>) -> tensor<8x128xf32> + %243 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %244 = "stablehlo.scatter"(%230, %243, %242) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %245 = "stablehlo.collective_permute"(%229) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %246 = stablehlo.compare EQ, %9, %c_2, SIGNED : (tensor, tensor) -> tensor + %247 = stablehlo.slice %199 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %248 = stablehlo.reshape %247 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %249 = stablehlo.slice %244 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %250 = stablehlo.reshape %249 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %251 = call @_where_18(%246, %248, %250) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %252 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %253 = "stablehlo.scatter"(%244, %252, %251) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %254 = stablehlo.dot_general %253, %arg2, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, tensor<2x128x128xf32>) -> tensor<2x8x128xf32> + %255 = stablehlo.broadcast_in_dim %arg3, dims 
= [0, 2] : (tensor<2x128xf32>) -> tensor<2x1x128xf32> + %256 = stablehlo.broadcast_in_dim %255, dims = [0, 1, 2] : (tensor<2x1x128xf32>) -> tensor<2x8x128xf32> + %257 = stablehlo.add %254, %256 : tensor<2x8x128xf32> + %258 = call @relu_19(%257) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %259 = stablehlo.compare EQ, %9, %c_3, SIGNED : (tensor, tensor) -> tensor + %260 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %261 = stablehlo.add %c_6, %c_7 : tensor + %262 = stablehlo.select %260, %261, %c_4 : tensor, tensor + %263 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %264 = stablehlo.add %c_2, %c_8 : tensor + %265 = stablehlo.select %263, %264, %c_5 : tensor, tensor + %266 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %267 = stablehlo.add %c_2, %c_9 : tensor + %268 = stablehlo.select %266, %267, %c_5 : tensor, tensor + %269 = stablehlo.dynamic_slice %258, %262, %265, %268, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %270 = stablehlo.reshape %269 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %271 = stablehlo.slice %245 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %272 = stablehlo.reshape %271 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %273 = call @_where_20(%259, %270, %272) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %274 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %275 = "stablehlo.scatter"(%245, %274, %273) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %276 = call @_roll_static_21(%258) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %277 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %278 = stablehlo.add 
%c_6, %c_7 : tensor + %279 = stablehlo.select %277, %278, %c_4 : tensor, tensor + %280 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %281 = stablehlo.add %c_2, %c_8 : tensor + %282 = stablehlo.select %280, %281, %c_5 : tensor, tensor + %283 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %284 = stablehlo.add %c_2, %c_9 : tensor + %285 = stablehlo.select %283, %284, %c_5 : tensor, tensor + %286 = stablehlo.dynamic_slice %258, %279, %282, %285, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %287 = stablehlo.reshape %286 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %288 = "stablehlo.collective_permute"(%287) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<8x128xf32>) -> tensor<8x128xf32> + %289 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %290 = "stablehlo.scatter"(%276, %289, %288) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %291 = "stablehlo.collective_permute"(%199) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %292 = stablehlo.compare EQ, %9, %c_2, SIGNED : (tensor, tensor) -> tensor + %293 = stablehlo.slice %291 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %294 = stablehlo.reshape %293 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %295 = stablehlo.slice %290 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %296 = stablehlo.reshape %295 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %297 = call 
@_where_22(%292, %294, %296) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %298 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %299 = "stablehlo.scatter"(%290, %298, %297) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %300 = stablehlo.dot_general %299, %arg2, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, tensor<2x128x128xf32>) -> tensor<2x8x128xf32> + %301 = stablehlo.broadcast_in_dim %arg3, dims = [0, 2] : (tensor<2x128xf32>) -> tensor<2x1x128xf32> + %302 = stablehlo.broadcast_in_dim %301, dims = [0, 1, 2] : (tensor<2x1x128xf32>) -> tensor<2x8x128xf32> + %303 = stablehlo.add %300, %302 : tensor<2x8x128xf32> + %304 = call @relu_23(%303) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %305 = stablehlo.compare EQ, %9, %c_3, SIGNED : (tensor, tensor) -> tensor + %306 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %307 = stablehlo.add %c_6, %c_7 : tensor + %308 = stablehlo.select %306, %307, %c_4 : tensor, tensor + %309 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %310 = stablehlo.add %c_2, %c_8 : tensor + %311 = stablehlo.select %309, %310, %c_5 : tensor, tensor + %312 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %313 = stablehlo.add %c_2, %c_9 : tensor + %314 = stablehlo.select %312, %313, %c_5 : tensor, tensor + %315 = stablehlo.dynamic_slice %304, %308, %311, %314, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %316 = stablehlo.reshape %315 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %317 = stablehlo.slice %275 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %318 = stablehlo.reshape 
%317 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %319 = call @_where_24(%305, %316, %318) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %320 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<1xi32> + %321 = "stablehlo.scatter"(%275, %320, %319) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + stablehlo.return %arg9 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %322 = "stablehlo.collective_permute"(%321) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %323 = "stablehlo.collective_permute"(%322) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %324 = stablehlo.dot_general %323, %arg4, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, tensor<128x8xf32>) -> tensor<2x8x8xf32> + %325 = stablehlo.broadcast_in_dim %arg5, dims = [2] : (tensor<8xf32>) -> tensor<1x1x8xf32> + %326 = stablehlo.broadcast_in_dim %325, dims = [0, 1, 2] : (tensor<1x1x8xf32>) -> tensor<2x8x8xf32> + %327 = stablehlo.add %324, %326 : tensor<2x8x8xf32> + %328 = stablehlo.reshape %327 : (tensor<2x8x8xf32>) -> tensor<16x8xf32> + %329 = stablehlo.subtract %328, %arg7 : tensor<16x8xf32> + %330 = stablehlo.multiply %329, %329 : tensor<16x8xf32> + %cst_10 = stablehlo.constant dense<0.000000e+00> : tensor + %331 = stablehlo.reduce(%330 init: %cst_10) applies stablehlo.add across dimensions = [1] : (tensor<16x8xf32>, tensor) -> tensor<16xf32> + %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor + %332 = stablehlo.reduce(%331 init: %cst_11) applies stablehlo.add 
across dimensions = [0] : (tensor<16xf32>, tensor) -> tensor + %cst_12 = stablehlo.constant dense<1.600000e+01> : tensor + %333 = stablehlo.divide %332, %cst_12 : tensor + %334 = "stablehlo.all_reduce"(%333) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> ({ + ^bb0(%arg8: tensor, %arg9: tensor): + %336 = stablehlo.add %arg8, %arg9 : tensor + stablehlo.return %336 : tensor + }) : (tensor) -> tensor + %cst_13 = stablehlo.constant dense<2.000000e+00> : tensor + %335 = stablehlo.divide %334, %cst_13 : tensor + return %335 : tensor + } + func.func private @relu(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @relu_0(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where_1(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @_roll_static(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %0 = stablehlo.slice %arg0 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %1 = stablehlo.slice %arg0 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + 
%2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<1x8x128xf32>, tensor<1x8x128xf32>) -> tensor<2x8x128xf32> + return %2 : tensor<2x8x128xf32> + } + func.func private @_where_2(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @relu_3(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where_4(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @_roll_static_5(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %0 = stablehlo.slice %arg0 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %1 = stablehlo.slice %arg0 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<1x8x128xf32>, tensor<1x8x128xf32>) -> tensor<2x8x128xf32> + return %2 : tensor<2x8x128xf32> + } + func.func private @_where_6(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @relu_7(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where_8(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> 
tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @_roll_static_9(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %0 = stablehlo.slice %arg0 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %1 = stablehlo.slice %arg0 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<1x8x128xf32>, tensor<1x8x128xf32>) -> tensor<2x8x128xf32> + return %2 : tensor<2x8x128xf32> + } + func.func private @_where_10(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @relu_11(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where_12(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @_roll_static_13(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %0 = stablehlo.slice %arg0 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %1 = stablehlo.slice %arg0 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<1x8x128xf32>, tensor<1x8x128xf32>) -> tensor<2x8x128xf32> + return %2 : tensor<2x8x128xf32> + } + func.func private @_where_14(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + 
func.func private @relu_15(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where_16(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @_roll_static_17(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %0 = stablehlo.slice %arg0 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %1 = stablehlo.slice %arg0 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<1x8x128xf32>, tensor<1x8x128xf32>) -> tensor<2x8x128xf32> + return %2 : tensor<2x8x128xf32> + } + func.func private @_where_18(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @relu_19(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where_20(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @_roll_static_21(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %0 = stablehlo.slice %arg0 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %1 = stablehlo.slice %arg0 [0:1, 0:8, 0:128] : 
(tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<1x8x128xf32>, tensor<1x8x128xf32>) -> tensor<2x8x128xf32> + return %2 : tensor<2x8x128xf32> + } + func.func private @_where_22(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @relu_23(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where_24(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } +} + +// CHECK-LABEL: @main diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_pipeline_shardy.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_pipeline_shardy.mlir new file mode 100644 index 0000000000..7927b95c67 --- /dev/null +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_pipeline_shardy.mlir @@ -0,0 +1,571 @@ +// REQUIRES: stablehlo +// RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s +// UNSUPPORTED: true + +module @jit_loss_pp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { + sdy.mesh @mesh = <["stages"=2, "y"=4]> + func.func public @main(%arg0: tensor<784x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}, %arg1: tensor<128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}]>}, %arg2: tensor<4x128x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"stages"}, {}, {}]>}, %arg3: tensor<4x128xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"stages"}, {}]>}, %arg4: tensor<128x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}, {}]>}, %arg5:
tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{}]>}, %arg6: tensor<32x784xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"stages"}, {}]>}, %arg7: tensor<32x8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"stages"}, {}]>}) -> (tensor {jax.result_info = ""}) { + %0 = sdy.manual_computation(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) in_shardings=[<@mesh, [{}, {}]>, <@mesh, [{}]>, <@mesh, [{"stages"}, {}, {}]>, <@mesh, [{"stages"}, {}]>, <@mesh, [{}, {}]>, <@mesh, [{}]>, <@mesh, [{"stages"}, {}]>, <@mesh, [{"stages"}, {}]>] out_shardings=[<@mesh, []>] manual_axes={"stages", "y"} (%arg8: tensor<784x128xf32>, %arg9: tensor<128xf32>, %arg10: tensor<2x128x128xf32>, %arg11: tensor<2x128xf32>, %arg12: tensor<128x8xf32>, %arg13: tensor<8xf32>, %arg14: tensor<16x784xf32>, %arg15: tensor<16x8xf32>) { + %1 = stablehlo.reshape %arg14 : (tensor<16x784xf32>) -> tensor<2x8x784xf32> + %2 = stablehlo.dot_general %1, %arg8, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x8x784xf32>, tensor<784x128xf32>) -> tensor<2x8x128xf32> + %3 = stablehlo.broadcast_in_dim %arg9, dims = [2] : (tensor<128xf32>) -> tensor<1x1x128xf32> + %4 = stablehlo.broadcast_in_dim %3, dims = [0, 1, 2] : (tensor<1x1x128xf32>) -> tensor<2x8x128xf32> + %5 = stablehlo.add %2, %4 : tensor<2x8x128xf32> + %6 = func.call @relu(%5) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %c = stablehlo.constant dense<4> : tensor + %c_0 = stablehlo.constant dense<2> : tensor + %7 = stablehlo.partition_id : tensor + %8 = stablehlo.divide %7, %c : tensor + %9 = stablehlo.remainder %8, %c_0 : tensor + %10 = stablehlo.convert %9 : (tensor) -> tensor + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %11 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %cst_1 = stablehlo.constant dense<0x7FC00000> : tensor + %12 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor) -> tensor<2x8x128xf32> + %13 = stablehlo.multiply %11, %12 : tensor<2x8x128xf32> + %14 = 
stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %15 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor) -> tensor<2x8x128xf32> + %16 = stablehlo.multiply %14, %15 : tensor<2x8x128xf32> + %c_2 = stablehlo.constant dense<0> : tensor + %17 = stablehlo.compare EQ, %10, %c_2, SIGNED : (tensor, tensor) -> tensor + %18 = stablehlo.slice %6 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %19 = stablehlo.reshape %18 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %20 = stablehlo.slice %16 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %21 = stablehlo.reshape %20 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %22 = func.call @_where(%17, %19, %21) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %23 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %24 = "stablehlo.scatter"(%16, %23, %22) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %25 = stablehlo.dot_general %24, %arg10, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, tensor<2x128x128xf32>) -> tensor<2x8x128xf32> + %26 = stablehlo.broadcast_in_dim %arg11, dims = [0, 2] : (tensor<2x128xf32>) -> tensor<2x1x128xf32> + %27 = stablehlo.broadcast_in_dim %26, dims = [0, 1, 2] : (tensor<2x1x128xf32>) -> tensor<2x8x128xf32> + %28 = stablehlo.add %25, %27 : tensor<2x8x128xf32> + %29 = func.call @relu_0(%28) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %c_3 = stablehlo.constant dense<1> : tensor + %30 = stablehlo.compare EQ, %10, %c_3, SIGNED : (tensor, tensor) -> tensor + %c_4 = stablehlo.constant dense<-1> : tensor + %c_5 = stablehlo.constant dense<0> : tensor + %31 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) 
-> tensor + %c_6 = stablehlo.constant dense<-1> : tensor + %c_7 = stablehlo.constant dense<2> : tensor + %32 = stablehlo.add %c_6, %c_7 : tensor + %33 = stablehlo.select %31, %32, %c_4 : tensor, tensor + %34 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %c_8 = stablehlo.constant dense<8> : tensor + %35 = stablehlo.add %c_2, %c_8 : tensor + %36 = stablehlo.select %34, %35, %c_5 : tensor, tensor + %37 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %c_9 = stablehlo.constant dense<128> : tensor + %38 = stablehlo.add %c_2, %c_9 : tensor + %39 = stablehlo.select %37, %38, %c_5 : tensor, tensor + %40 = stablehlo.dynamic_slice %29, %33, %36, %39, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %41 = stablehlo.reshape %40 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %42 = stablehlo.slice %13 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %43 = stablehlo.reshape %42 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %44 = func.call @_where_1(%30, %41, %43) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %45 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<1xi32> + %46 = "stablehlo.scatter"(%13, %45, %44) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %47 = func.call @_roll_static(%29) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %48 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %49 = stablehlo.add %c_6, %c_7 : tensor + %50 = stablehlo.select %48, %49, %c_4 : tensor, tensor + %51 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %52 = stablehlo.add %c_2, %c_8 : tensor + %53 = stablehlo.select %51, %52, %c_5 : tensor, tensor + %54 = stablehlo.compare 
LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %55 = stablehlo.add %c_2, %c_9 : tensor + %56 = stablehlo.select %54, %55, %c_5 : tensor, tensor + %57 = stablehlo.dynamic_slice %29, %50, %53, %56, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %58 = stablehlo.reshape %57 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %59 = "stablehlo.collective_permute"(%58) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<8x128xf32>) -> tensor<8x128xf32> + %60 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %61 = "stablehlo.scatter"(%47, %60, %59) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %62 = "stablehlo.collective_permute"(%46) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %63 = stablehlo.compare EQ, %10, %c_2, SIGNED : (tensor, tensor) -> tensor + %64 = stablehlo.slice %6 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %65 = stablehlo.reshape %64 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %66 = stablehlo.slice %61 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %67 = stablehlo.reshape %66 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %68 = func.call @_where_2(%63, %65, %67) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %69 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %70 = "stablehlo.scatter"(%61, %69, %68) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = 
true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %71 = stablehlo.dot_general %70, %arg10, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, tensor<2x128x128xf32>) -> tensor<2x8x128xf32> + %72 = stablehlo.broadcast_in_dim %arg11, dims = [0, 2] : (tensor<2x128xf32>) -> tensor<2x1x128xf32> + %73 = stablehlo.broadcast_in_dim %72, dims = [0, 1, 2] : (tensor<2x1x128xf32>) -> tensor<2x8x128xf32> + %74 = stablehlo.add %71, %73 : tensor<2x8x128xf32> + %75 = func.call @relu_3(%74) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %76 = stablehlo.compare EQ, %10, %c_3, SIGNED : (tensor, tensor) -> tensor + %77 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %78 = stablehlo.add %c_6, %c_7 : tensor + %79 = stablehlo.select %77, %78, %c_4 : tensor, tensor + %80 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %81 = stablehlo.add %c_2, %c_8 : tensor + %82 = stablehlo.select %80, %81, %c_5 : tensor, tensor + %83 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %84 = stablehlo.add %c_2, %c_9 : tensor + %85 = stablehlo.select %83, %84, %c_5 : tensor, tensor + %86 = stablehlo.dynamic_slice %75, %79, %82, %85, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %87 = stablehlo.reshape %86 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %88 = stablehlo.slice %62 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %89 = stablehlo.reshape %88 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %90 = func.call @_where_4(%76, %87, %89) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %91 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %92 = "stablehlo.scatter"(%62, %91, %90) <{indices_are_sorted = true, scatter_dimension_numbers = 
#stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %93 = func.call @_roll_static_5(%75) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %94 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %95 = stablehlo.add %c_6, %c_7 : tensor + %96 = stablehlo.select %94, %95, %c_4 : tensor, tensor + %97 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %98 = stablehlo.add %c_2, %c_8 : tensor + %99 = stablehlo.select %97, %98, %c_5 : tensor, tensor + %100 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %101 = stablehlo.add %c_2, %c_9 : tensor + %102 = stablehlo.select %100, %101, %c_5 : tensor, tensor + %103 = stablehlo.dynamic_slice %75, %96, %99, %102, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %104 = stablehlo.reshape %103 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %105 = "stablehlo.collective_permute"(%104) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<8x128xf32>) -> tensor<8x128xf32> + %106 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %107 = "stablehlo.scatter"(%93, %106, %105) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %108 = "stablehlo.collective_permute"(%6) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %109 = stablehlo.compare EQ, %10, %c_2, SIGNED : 
(tensor, tensor) -> tensor + %110 = stablehlo.slice %108 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %111 = stablehlo.reshape %110 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %112 = stablehlo.slice %107 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %113 = stablehlo.reshape %112 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %114 = func.call @_where_6(%109, %111, %113) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %115 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %116 = "stablehlo.scatter"(%107, %115, %114) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %117 = stablehlo.dot_general %116, %arg10, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, tensor<2x128x128xf32>) -> tensor<2x8x128xf32> + %118 = stablehlo.broadcast_in_dim %arg11, dims = [0, 2] : (tensor<2x128xf32>) -> tensor<2x1x128xf32> + %119 = stablehlo.broadcast_in_dim %118, dims = [0, 1, 2] : (tensor<2x1x128xf32>) -> tensor<2x8x128xf32> + %120 = stablehlo.add %117, %119 : tensor<2x8x128xf32> + %121 = func.call @relu_7(%120) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %122 = stablehlo.compare EQ, %10, %c_3, SIGNED : (tensor, tensor) -> tensor + %123 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %124 = stablehlo.add %c_6, %c_7 : tensor + %125 = stablehlo.select %123, %124, %c_4 : tensor, tensor + %126 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %127 = stablehlo.add %c_2, %c_8 : tensor + %128 = stablehlo.select %126, %127, %c_5 : tensor, tensor + %129 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %130 = stablehlo.add %c_2, %c_9 : tensor + %131 
= stablehlo.select %129, %130, %c_5 : tensor, tensor + %132 = stablehlo.dynamic_slice %121, %125, %128, %131, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %133 = stablehlo.reshape %132 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %134 = stablehlo.slice %92 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %135 = stablehlo.reshape %134 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %136 = func.call @_where_8(%122, %133, %135) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %137 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<1xi32> + %138 = "stablehlo.scatter"(%92, %137, %136) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %139 = func.call @_roll_static_9(%121) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %140 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %141 = stablehlo.add %c_6, %c_7 : tensor + %142 = stablehlo.select %140, %141, %c_4 : tensor, tensor + %143 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %144 = stablehlo.add %c_2, %c_8 : tensor + %145 = stablehlo.select %143, %144, %c_5 : tensor, tensor + %146 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %147 = stablehlo.add %c_2, %c_9 : tensor + %148 = stablehlo.select %146, %147, %c_5 : tensor, tensor + %149 = stablehlo.dynamic_slice %121, %142, %145, %148, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %150 = stablehlo.reshape %149 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %151 = "stablehlo.collective_permute"(%150) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 
3]]> : tensor<8x2xi64>}> : (tensor<8x128xf32>) -> tensor<8x128xf32> + %152 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %153 = "stablehlo.scatter"(%139, %152, %151) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %154 = "stablehlo.collective_permute"(%138) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %155 = stablehlo.compare EQ, %10, %c_2, SIGNED : (tensor, tensor) -> tensor + %156 = stablehlo.slice %108 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %157 = stablehlo.reshape %156 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %158 = stablehlo.slice %153 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %159 = stablehlo.reshape %158 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %160 = func.call @_where_10(%155, %157, %159) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %161 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %162 = "stablehlo.scatter"(%153, %161, %160) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %163 = stablehlo.dot_general %162, %arg10, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, tensor<2x128x128xf32>) -> tensor<2x8x128xf32> + %164 = stablehlo.broadcast_in_dim %arg11, dims = [0, 2] : (tensor<2x128xf32>) -> tensor<2x1x128xf32> + %165 = stablehlo.broadcast_in_dim %164, dims = 
[0, 1, 2] : (tensor<2x1x128xf32>) -> tensor<2x8x128xf32> + %166 = stablehlo.add %163, %165 : tensor<2x8x128xf32> + %167 = func.call @relu_11(%166) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %168 = stablehlo.compare EQ, %10, %c_3, SIGNED : (tensor, tensor) -> tensor + %169 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %170 = stablehlo.add %c_6, %c_7 : tensor + %171 = stablehlo.select %169, %170, %c_4 : tensor, tensor + %172 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %173 = stablehlo.add %c_2, %c_8 : tensor + %174 = stablehlo.select %172, %173, %c_5 : tensor, tensor + %175 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %176 = stablehlo.add %c_2, %c_9 : tensor + %177 = stablehlo.select %175, %176, %c_5 : tensor, tensor + %178 = stablehlo.dynamic_slice %167, %171, %174, %177, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %179 = stablehlo.reshape %178 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %180 = stablehlo.slice %154 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %181 = stablehlo.reshape %180 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %182 = func.call @_where_12(%168, %179, %181) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %183 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %184 = "stablehlo.scatter"(%154, %183, %182) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %185 = func.call @_roll_static_13(%167) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %186 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %187 = stablehlo.add %c_6, %c_7 : tensor + %188 = stablehlo.select %186, %187, %c_4 : tensor, tensor + 
%189 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %190 = stablehlo.add %c_2, %c_8 : tensor + %191 = stablehlo.select %189, %190, %c_5 : tensor, tensor + %192 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %193 = stablehlo.add %c_2, %c_9 : tensor + %194 = stablehlo.select %192, %193, %c_5 : tensor, tensor + %195 = stablehlo.dynamic_slice %167, %188, %191, %194, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %196 = stablehlo.reshape %195 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %197 = "stablehlo.collective_permute"(%196) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<8x128xf32>) -> tensor<8x128xf32> + %198 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %199 = "stablehlo.scatter"(%185, %198, %197) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %200 = "stablehlo.collective_permute"(%108) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %201 = stablehlo.compare EQ, %10, %c_2, SIGNED : (tensor, tensor) -> tensor + %202 = stablehlo.slice %200 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %203 = stablehlo.reshape %202 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %204 = stablehlo.slice %199 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %205 = stablehlo.reshape %204 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %206 = func.call @_where_14(%201, %203, %205) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) 
-> tensor<8x128xf32> + %207 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %208 = "stablehlo.scatter"(%199, %207, %206) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %209 = stablehlo.dot_general %208, %arg10, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, tensor<2x128x128xf32>) -> tensor<2x8x128xf32> + %210 = stablehlo.broadcast_in_dim %arg11, dims = [0, 2] : (tensor<2x128xf32>) -> tensor<2x1x128xf32> + %211 = stablehlo.broadcast_in_dim %210, dims = [0, 1, 2] : (tensor<2x1x128xf32>) -> tensor<2x8x128xf32> + %212 = stablehlo.add %209, %211 : tensor<2x8x128xf32> + %213 = func.call @relu_15(%212) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %214 = stablehlo.compare EQ, %10, %c_3, SIGNED : (tensor, tensor) -> tensor + %215 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %216 = stablehlo.add %c_6, %c_7 : tensor + %217 = stablehlo.select %215, %216, %c_4 : tensor, tensor + %218 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %219 = stablehlo.add %c_2, %c_8 : tensor + %220 = stablehlo.select %218, %219, %c_5 : tensor, tensor + %221 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %222 = stablehlo.add %c_2, %c_9 : tensor + %223 = stablehlo.select %221, %222, %c_5 : tensor, tensor + %224 = stablehlo.dynamic_slice %213, %217, %220, %223, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %225 = stablehlo.reshape %224 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %226 = stablehlo.slice %184 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %227 = stablehlo.reshape %226 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %228 = 
func.call @_where_16(%214, %225, %227) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %229 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<1xi32> + %230 = "stablehlo.scatter"(%184, %229, %228) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %231 = func.call @_roll_static_17(%213) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %232 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %233 = stablehlo.add %c_6, %c_7 : tensor + %234 = stablehlo.select %232, %233, %c_4 : tensor, tensor + %235 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %236 = stablehlo.add %c_2, %c_8 : tensor + %237 = stablehlo.select %235, %236, %c_5 : tensor, tensor + %238 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %239 = stablehlo.add %c_2, %c_9 : tensor + %240 = stablehlo.select %238, %239, %c_5 : tensor, tensor + %241 = stablehlo.dynamic_slice %213, %234, %237, %240, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %242 = stablehlo.reshape %241 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %243 = "stablehlo.collective_permute"(%242) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<8x128xf32>) -> tensor<8x128xf32> + %244 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %245 = "stablehlo.scatter"(%231, %244, %243) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> 
tensor<2x8x128xf32> + %246 = "stablehlo.collective_permute"(%230) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %247 = stablehlo.compare EQ, %10, %c_2, SIGNED : (tensor, tensor) -> tensor + %248 = stablehlo.slice %200 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %249 = stablehlo.reshape %248 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %250 = stablehlo.slice %245 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %251 = stablehlo.reshape %250 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %252 = func.call @_where_18(%247, %249, %251) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %253 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %254 = "stablehlo.scatter"(%245, %253, %252) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %255 = stablehlo.dot_general %254, %arg10, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, tensor<2x128x128xf32>) -> tensor<2x8x128xf32> + %256 = stablehlo.broadcast_in_dim %arg11, dims = [0, 2] : (tensor<2x128xf32>) -> tensor<2x1x128xf32> + %257 = stablehlo.broadcast_in_dim %256, dims = [0, 1, 2] : (tensor<2x1x128xf32>) -> tensor<2x8x128xf32> + %258 = stablehlo.add %255, %257 : tensor<2x8x128xf32> + %259 = func.call @relu_19(%258) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %260 = stablehlo.compare EQ, %10, %c_3, SIGNED : (tensor, tensor) -> tensor + %261 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %262 = stablehlo.add %c_6, %c_7 : tensor + %263 = stablehlo.select %261, %262, %c_4 : 
tensor, tensor + %264 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %265 = stablehlo.add %c_2, %c_8 : tensor + %266 = stablehlo.select %264, %265, %c_5 : tensor, tensor + %267 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %268 = stablehlo.add %c_2, %c_9 : tensor + %269 = stablehlo.select %267, %268, %c_5 : tensor, tensor + %270 = stablehlo.dynamic_slice %259, %263, %266, %269, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %271 = stablehlo.reshape %270 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %272 = stablehlo.slice %246 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %273 = stablehlo.reshape %272 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %274 = func.call @_where_20(%260, %271, %273) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %275 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %276 = "stablehlo.scatter"(%246, %275, %274) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %277 = func.call @_roll_static_21(%259) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %278 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %279 = stablehlo.add %c_6, %c_7 : tensor + %280 = stablehlo.select %278, %279, %c_4 : tensor, tensor + %281 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %282 = stablehlo.add %c_2, %c_8 : tensor + %283 = stablehlo.select %281, %282, %c_5 : tensor, tensor + %284 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %285 = stablehlo.add %c_2, %c_9 : tensor + %286 = stablehlo.select %284, %285, %c_5 : tensor, tensor + %287 = stablehlo.dynamic_slice %259, %280, %283, %286, sizes = [1, 8, 128] : 
(tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %288 = stablehlo.reshape %287 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %289 = "stablehlo.collective_permute"(%288) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<8x128xf32>) -> tensor<8x128xf32> + %290 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %291 = "stablehlo.scatter"(%277, %290, %289) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %292 = "stablehlo.collective_permute"(%200) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %293 = stablehlo.compare EQ, %10, %c_2, SIGNED : (tensor, tensor) -> tensor + %294 = stablehlo.slice %292 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %295 = stablehlo.reshape %294 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %296 = stablehlo.slice %291 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %297 = stablehlo.reshape %296 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %298 = func.call @_where_22(%293, %295, %297) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %299 = stablehlo.broadcast_in_dim %c_2, dims = [] : (tensor) -> tensor<1xi32> + %300 = "stablehlo.scatter"(%291, %299, %298) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : (tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %301 = 
stablehlo.dot_general %300, %arg10, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, tensor<2x128x128xf32>) -> tensor<2x8x128xf32> + %302 = stablehlo.broadcast_in_dim %arg11, dims = [0, 2] : (tensor<2x128xf32>) -> tensor<2x1x128xf32> + %303 = stablehlo.broadcast_in_dim %302, dims = [0, 1, 2] : (tensor<2x1x128xf32>) -> tensor<2x8x128xf32> + %304 = stablehlo.add %301, %303 : tensor<2x8x128xf32> + %305 = func.call @relu_23(%304) : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %306 = stablehlo.compare EQ, %10, %c_3, SIGNED : (tensor, tensor) -> tensor + %307 = stablehlo.compare LT, %c_4, %c_5, SIGNED : (tensor, tensor) -> tensor + %308 = stablehlo.add %c_6, %c_7 : tensor + %309 = stablehlo.select %307, %308, %c_4 : tensor, tensor + %310 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %311 = stablehlo.add %c_2, %c_8 : tensor + %312 = stablehlo.select %310, %311, %c_5 : tensor, tensor + %313 = stablehlo.compare LT, %c_5, %c_5, SIGNED : (tensor, tensor) -> tensor + %314 = stablehlo.add %c_2, %c_9 : tensor + %315 = stablehlo.select %313, %314, %c_5 : tensor, tensor + %316 = stablehlo.dynamic_slice %305, %309, %312, %315, sizes = [1, 8, 128] : (tensor<2x8x128xf32>, tensor, tensor, tensor) -> tensor<1x8x128xf32> + %317 = stablehlo.reshape %316 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %318 = stablehlo.slice %276 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %319 = stablehlo.reshape %318 : (tensor<1x8x128xf32>) -> tensor<8x128xf32> + %320 = func.call @_where_24(%306, %317, %319) : (tensor, tensor<8x128xf32>, tensor<8x128xf32>) -> tensor<8x128xf32> + %321 = stablehlo.broadcast_in_dim %c_3, dims = [] : (tensor) -> tensor<1xi32> + %322 = "stablehlo.scatter"(%276, %321, %320) <{indices_are_sorted = true, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = true}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + stablehlo.return %arg17 : tensor + }) : 
(tensor<2x8x128xf32>, tensor<1xi32>, tensor<8x128xf32>) -> tensor<2x8x128xf32> + %323 = "stablehlo.collective_permute"(%322) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %324 = "stablehlo.collective_permute"(%323) <{channel_handle = #stablehlo.channel_handle, source_target_pairs = dense<[[0, 4], [4, 0], [1, 5], [5, 1], [2, 6], [6, 2], [3, 7], [7, 3]]> : tensor<8x2xi64>}> : (tensor<2x8x128xf32>) -> tensor<2x8x128xf32> + %325 = stablehlo.dot_general %324, %arg12, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<2x8x128xf32>, tensor<128x8xf32>) -> tensor<2x8x8xf32> + %326 = stablehlo.broadcast_in_dim %arg13, dims = [2] : (tensor<8xf32>) -> tensor<1x1x8xf32> + %327 = stablehlo.broadcast_in_dim %326, dims = [0, 1, 2] : (tensor<1x1x8xf32>) -> tensor<2x8x8xf32> + %328 = stablehlo.add %325, %327 : tensor<2x8x8xf32> + %329 = stablehlo.reshape %328 : (tensor<2x8x8xf32>) -> tensor<16x8xf32> + %330 = stablehlo.subtract %329, %arg15 : tensor<16x8xf32> + %331 = stablehlo.multiply %330, %330 : tensor<16x8xf32> + %cst_10 = stablehlo.constant dense<0.000000e+00> : tensor + %332 = stablehlo.reduce(%331 init: %cst_10) applies stablehlo.add across dimensions = [1] : (tensor<16x8xf32>, tensor) -> tensor<16xf32> + %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor + %333 = stablehlo.reduce(%332 init: %cst_11) applies stablehlo.add across dimensions = [0] : (tensor<16xf32>, tensor) -> tensor + %cst_12 = stablehlo.constant dense<1.600000e+01> : tensor + %334 = stablehlo.divide %333, %cst_12 : tensor + %335 = "stablehlo.all_reduce"(%334) <{channel_handle = #stablehlo.channel_handle, replica_groups = dense<[[0, 4], [1, 5], [2, 6], [3, 7]]> : tensor<4x2xi64>, use_global_device_ids}> ({ + ^bb0(%arg16: tensor, %arg17: tensor): + %337 = stablehlo.add %arg16, %arg17 : tensor + stablehlo.return %337 : 
tensor + }) : (tensor) -> tensor + %cst_13 = stablehlo.constant dense<2.000000e+00> : tensor + %336 = stablehlo.divide %335, %cst_13 : tensor + sdy.return %336 : tensor + } : (tensor<784x128xf32>, tensor<128xf32>, tensor<4x128x128xf32>, tensor<4x128xf32>, tensor<128x8xf32>, tensor<8xf32>, tensor<32x784xf32>, tensor<32x8xf32>) -> tensor + return %0 : tensor + } + func.func private @relu(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @relu_0(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where_1(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @_roll_static(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %0 = stablehlo.slice %arg0 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %1 = stablehlo.slice %arg0 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<1x8x128xf32>, tensor<1x8x128xf32>) -> tensor<2x8x128xf32> + return %2 : tensor<2x8x128xf32> + } + func.func private @_where_2(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 
= stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @relu_3(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where_4(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @_roll_static_5(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %0 = stablehlo.slice %arg0 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %1 = stablehlo.slice %arg0 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<1x8x128xf32>, tensor<1x8x128xf32>) -> tensor<2x8x128xf32> + return %2 : tensor<2x8x128xf32> + } + func.func private @_where_6(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @relu_7(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where_8(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @_roll_static_9(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %0 = stablehlo.slice %arg0 [1:2, 0:8, 0:128] : 
(tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %1 = stablehlo.slice %arg0 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<1x8x128xf32>, tensor<1x8x128xf32>) -> tensor<2x8x128xf32> + return %2 : tensor<2x8x128xf32> + } + func.func private @_where_10(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @relu_11(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where_12(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @_roll_static_13(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %0 = stablehlo.slice %arg0 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %1 = stablehlo.slice %arg0 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<1x8x128xf32>, tensor<1x8x128xf32>) -> tensor<2x8x128xf32> + return %2 : tensor<2x8x128xf32> + } + func.func private @_where_14(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @relu_15(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : 
tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where_16(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @_roll_static_17(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %0 = stablehlo.slice %arg0 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %1 = stablehlo.slice %arg0 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<1x8x128xf32>, tensor<1x8x128xf32>) -> tensor<2x8x128xf32> + return %2 : tensor<2x8x128xf32> + } + func.func private @_where_18(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @relu_19(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where_20(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @_roll_static_21(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %0 = stablehlo.slice %arg0 [1:2, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %1 = stablehlo.slice %arg0 [0:1, 0:8, 0:128] : (tensor<2x8x128xf32>) -> tensor<1x8x128xf32> + %2 = stablehlo.concatenate %0, %1, dim = 0 : (tensor<1x8x128xf32>, tensor<1x8x128xf32>) -> tensor<2x8x128xf32> + return %2 : tensor<2x8x128xf32> + } + func.func private @_where_22(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: 
tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } + func.func private @relu_23(%arg0: tensor<2x8x128xf32>) -> tensor<2x8x128xf32> { + %cst = stablehlo.constant dense<0.000000e+00> : tensor + %0 = stablehlo.broadcast_in_dim %cst, dims = [] : (tensor) -> tensor<2x8x128xf32> + %1 = stablehlo.maximum %arg0, %0 : tensor<2x8x128xf32> + return %1 : tensor<2x8x128xf32> + } + func.func private @_where_24(%arg0: tensor, %arg1: tensor<8x128xf32>, %arg2: tensor<8x128xf32>) -> tensor<8x128xf32> { + %0 = stablehlo.select %arg0, %arg1, %arg2 : tensor, tensor<8x128xf32> + return %0 : tensor<8x128xf32> + } +} + +// CHECK-LABEL @main diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_gspmd.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_gspmd.mlir index 68675eca89..0957b716bd 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_gspmd.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_gspmd.mlir @@ -1,6 +1,5 @@ // REQUIRES: stablehlo // RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s -// UNSUPPORTED: true module @jit_loss_tp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { func.func public @main(%arg0: tensor<784x128xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg1: tensor<128xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg2: tensor<128x128xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg3: tensor<128xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg4: tensor<128x128xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg5: tensor<128xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg6: tensor<128x128xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg7: tensor<128xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg8: tensor<128x128xf32> {mhlo.sharding = "{devices=[8,1]<=[8]}"}, %arg9: tensor<128xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg10: tensor<128x8xf32> {mhlo.sharding = 
"{devices=[8,1]<=[8]}"}, %arg11: tensor<8xf32> {mhlo.sharding = "{devices=[8]<=[8]}"}, %arg12: tensor<32x784xf32> {mhlo.sharding = "{devices=[1,8]<=[8]}"}, %arg13: tensor<32x8xf32> {mhlo.sharding = "{devices=[1,8]<=[8]}"}) -> (tensor {jax.result_info = ""}) { @@ -155,3 +154,103 @@ module @jit_loss_tp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas // CHECK-LABEL @main +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" 
+// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: 
shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type diff --git a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_shardy.mlir b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_shardy.mlir index f2e3b07be8..b8344dac96 100644 --- a/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_shardy.mlir +++ b/test/ttmlir/Conversion/StableHLOToTTIR/ccl/e2e_tp_shardy.mlir @@ -1,6 +1,5 @@ // REQUIRES: stablehlo // RUN: ttmlir-opt -split-input-file --stablehlo-to-ttir-pipeline %s | FileCheck %s -// UNSUPPORTED: true module @jit_loss_tp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas = 1 : i32} { sdy.mesh @mesh = <["x"=1, "y"=8]> @@ -101,3 +100,103 @@ module @jit_loss_tp attributes {mhlo.num_partitions = 8 : i32, mhlo.num_replicas } // CHECK-LABEL @main +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape 
= array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// 
CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type +// CHECK: "ttir.mesh_shard" +// CHECK-SAME: shard_dims = array +// CHECK-SAME: shard_direction = #tt.shard_direction +// CHECK-SAME: shard_shape = array +// CHECK-SAME: shard_type = #tt.shard_type