Change Manual shard_type to Identity and remove unnecessary mesh_shard op

wooseokTT committed Mar 6, 2025
1 parent 5af3d21 commit 17be714
Showing 25 changed files with 1,924 additions and 319 deletions.
113 changes: 69 additions & 44 deletions include/ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h
@@ -12,6 +12,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"

@@ -28,53 +29,36 @@ class MeshSharding {
llvm::Expected<bool>
convertGSPMDShardingToMeshSharding(StringRef shardingStr);

// Check and update arg sharding attribute and determine if mesh_shard op
// needs to be created or not.
bool checkAndUpdateGSPMDArgSharding(mlir::PatternRewriter &rewriter,
mlir::stablehlo::CustomCallOp srcOp,
mlir::StringAttr shardingAttr);

// Check and update ret sharding attribute and determine if mesh_shard op
// needs to be created or not.
bool checkAndUpdateGSPMDRetSharding(mlir::PatternRewriter &rewriter,
mlir::stablehlo::CustomCallOp srcOp,
mlir::StringAttr shardingAttr);

// Convert sdy.sharding to meshSharding.
llvm::Expected<bool>
convertSdyShardingToMeshSharding(mlir::sdy::TensorShardingAttr sdySharding,
mlir::sdy::MeshAttr mesh,
mlir::tt::MeshShardDirection direction);

// Check and update function arg sharding
template <typename AttrType>
void checkAndUpdateFuncArgSharding(mlir::PatternRewriter &rewriter,
mlir::func::FuncOp funcOp, uint64_t argNum,
AttrType shardingAttr,
llvm::StringRef argShardingStrRef) {
if (auto argShardingAttr =
funcOp.getArgAttrOfType<AttrType>(argNum, argShardingStrRef)) {
if (argShardingAttr == shardingAttr) {
setDummyShardingOp();
rewriter.modifyOpInPlace(
funcOp, [&]() { funcOp.removeArgAttr(argNum, argShardingStrRef); });
} else {
llvm_unreachable(
"MeshSharding operation and function argument shardings "
"are different.");
}
}
}
// Check and update arg sharding attribute and determine if
// mesh_shard op needs to be created or not.
bool checkAndUpdateShardyArgSharding(
mlir::PatternRewriter &rewriter, mlir::func::FuncOp funcOp,
mlir::Value argOperand, mlir::sdy::TensorShardingAttr shardingAttr);

// Check and update function ret sharding
template <typename AttrType>
void checkAndUpdateFuncReturnSharding(mlir::PatternRewriter &rewriter,
mlir::func::FuncOp funcOp,
uint64_t retNum, AttrType shardingAttr,
llvm::StringRef retShardingStrRef) {
if (auto retShardingAttr =
funcOp.getResultAttrOfType<AttrType>(retNum, retShardingStrRef)) {
if (retShardingAttr == shardingAttr) {
setDummyShardingOp();
rewriter.modifyOpInPlace(funcOp, [&]() {
funcOp.removeResultAttr(
retNum,
mlir::StringAttr::get(rewriter.getContext(), retShardingStrRef));
});
} else {
llvm_unreachable("MeshSharding operation and function return shardings "
"are different.");
}
}
}
// Check and update ret sharding attribute and determine if mesh_shard op
// needs to be created or not.
bool
checkAndUpdateShardyRetSharding(mlir::PatternRewriter &rewriter,
mlir::func::FuncOp funcOp, uint64_t retIdx,
mlir::sdy::TensorShardingAttr shardingAttr);

// Getter functions.
mlir::tt::MeshShardDirection getShardDirection() const {
@@ -103,21 +87,62 @@ class MeshSharding {
meshShape = llvm::SmallVector<int64_t>{-1};
}

- // Force dummy sharding op by setting shard_type to manual. The mesh_shard op
- // will be ignored at runtime by simply copying input tensor to output.
- void setDummyShardingOp() { shardType = mlir::tt::MeshShardType::Manual; }
+ // Force dummy sharding op by setting shard_type to identity. The mesh_shard
+ // op will be ignored at runtime by simply copying input tensor to output.
+ void setDummyShardingOp() { shardType = mlir::tt::MeshShardType::Identity; }

private:
mlir::tt::MeshShardDirection shardDirection =
mlir::tt::MeshShardDirection::ShardToFull;
- mlir::tt::MeshShardType shardType = mlir::tt::MeshShardType::Manual;
+ mlir::tt::MeshShardType shardType = mlir::tt::MeshShardType::Identity;
llvm::SmallVector<int64_t> shardShape{-1};
llvm::SmallVector<int64_t> shardDims{-1};
llvm::SmallVector<int64_t> meshShape{-1};
llvm::SmallVector<int64_t> deviceIds{-1};
bool lastTileDimReplicate = false;
};

// Remove arg sharding and return true if it is found, otherwise return false.
template <typename AttrType>
bool checkAndRemoveFuncArgSharding(mlir::PatternRewriter &rewriter,
mlir::func::FuncOp funcOp, uint64_t argNum,
AttrType shardingAttr,
llvm::StringRef argShardingStrRef) {
if (auto argShardingAttr =
funcOp.getArgAttrOfType<AttrType>(argNum, argShardingStrRef)) {
if (argShardingAttr == shardingAttr) {
rewriter.modifyOpInPlace(
funcOp, [&]() { funcOp.removeArgAttr(argNum, argShardingStrRef); });
return true;
}
llvm_unreachable("MeshSharding operation and function argument shardings "
"are different.");
}
return false;
}

// Remove ret sharding and return true if it is found, otherwise return false.
template <typename AttrType>
bool checkAndRemoveFuncReturnSharding(mlir::PatternRewriter &rewriter,
mlir::func::FuncOp funcOp,
uint64_t retIdx, AttrType shardingAttr,
llvm::StringRef retShardingStrRef) {
if (auto retShardingAttr =
funcOp.getResultAttrOfType<AttrType>(retIdx, retShardingStrRef)) {
if (retShardingAttr == shardingAttr) {
rewriter.modifyOpInPlace(funcOp, [&]() {
funcOp.removeResultAttr(
retIdx,
mlir::StringAttr::get(rewriter.getContext(), retShardingStrRef));
});
return true;
}
llvm_unreachable("MeshSharding operation and function return shardings "
"are different.");
}
return false;
}

// Sharding-related string definitions from OpenXLA
// https://github.com/openxla/xla/blob/main/xla/service/spmd/shardy/constants.h

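To show how the new bool-returning helpers are meant to be used, here is a hypothetical call site. This is a sketch only: the surrounding pattern, the argNum variable, and the "mhlo.sharding" attribute string are illustrative assumptions, not the actual tt-mlir call sites.

// Sketch of a caller inside a StableHLO-to-TTIR conversion pattern. If the
// function argument already carries an identical sharding attribute, the
// helper strips it and returns true, and the caller downgrades the
// mesh_shard op to an identity (runtime simply copies input to output).
if (checkAndRemoveFuncArgSharding(rewriter, funcOp, argNum, shardingAttr,
                                  "mhlo.sharding")) {
  meshSharding.setDummyShardingOp();
}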
4 changes: 2 additions & 2 deletions include/ttmlir/Dialect/TT/IR/TTOpsEnums.td
@@ -164,14 +164,14 @@ def TT_MeshShardDirection: I32EnumAttr<"MeshShardDirection", "TT MeshShardDirect
let cppNamespace = "::mlir::tt";
}

- def TT_MeshShardType_Manual : I32EnumAttrCase<"Manual", 0, "manual">;
+ def TT_MeshShardType_Identity : I32EnumAttrCase<"Identity", 0, "identity">;
def TT_MeshShardType_Replicate : I32EnumAttrCase<"Replicate", 1, "replicate">;
def TT_MeshShardType_Maximal : I32EnumAttrCase<"Maximal", 2, "maximal">;
def TT_MeshShardType_Devices : I32EnumAttrCase<"Devices", 3, "devices">;

def TT_MeshShardType: I32EnumAttr<"MeshShardType", "TT MeshShardType",
[
- TT_MeshShardType_Manual,
+ TT_MeshShardType_Identity,
TT_MeshShardType_Replicate,
TT_MeshShardType_Maximal,
TT_MeshShardType_Devices,
8 changes: 8 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -464,6 +464,14 @@ def TT_MeshShardDirectionAttr : EnumAttr<TT_Dialect, TT_MeshShardDirection, "sha
}

def TT_MeshShardTypeAttr : EnumAttr<TT_Dialect, TT_MeshShardType, "shard_type"> {
let summary = "MeshShard shard_type attribute in TT dialect";
let description = [{
Defines the sharded tensor data of a mesh_shard op.
- Identity: input and output tensors are pre-sharded (same data), so no sharding is required.
- Replicate: all devices hold the full tensor (same data).
- Maximal: one device or a subset of devices holds the full tensor (same data).
- Devices: all devices or a subset of devices hold a sharded (partial) tensor (different data).
}];
let assemblyFormat = "`<` $value `>`";
}

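To make the four shard types concrete, the sketch below shows how a runtime consumer could dispatch on them. Everything here is illustrative: Tensor and the distribution helpers (copy, broadcastAll, placeMaximal, scatter) are hypothetical placeholders, not tt-mlir or TTNN APIs.

// Hypothetical runtime-side dispatch over MeshShardType.
Tensor applyMeshShard(mlir::tt::MeshShardType shardType, const Tensor &input) {
  switch (shardType) {
  case mlir::tt::MeshShardType::Identity:
    return copy(input);         // pre-sharded: plain copy, no data movement
  case mlir::tt::MeshShardType::Replicate:
    return broadcastAll(input); // full tensor on every device
  case mlir::tt::MeshShardType::Maximal:
    return placeMaximal(input); // full tensor on one device or a subset
  case mlir::tt::MeshShardType::Devices:
    return scatter(input);      // partial (sharded) tensor per device
  }
  llvm_unreachable("unknown MeshShardType");
}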
3 changes: 2 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -1503,7 +1503,8 @@ def TTNN_MeshShardOp: TTNN_Op<"mesh_shard"> {

let extraClassDeclaration = [{
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
- return wa::TTNNOperandsWorkaroundsFactory::createMeshShardOpOperandsWorkarounds();
+ mlir::tt::ttnn::MeshShardOp op = mlir::cast<mlir::tt::ttnn::MeshShardOp>(this->getOperation());
+ return wa::TTNNOperandsWorkaroundsFactory::createMeshShardOpOperandsWorkarounds(op.getShardType());
}
}];

3 changes: 2 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h
@@ -235,7 +235,8 @@ class TTNNOperandsWorkaroundsFactory {
createFullOpOperandsWorkarounds(RankedTensorType outputType);

// Create workarounds for mesh shard op operands.
- static TTNNOperandsWorkarounds createMeshShardOpOperandsWorkarounds();
+ static TTNNOperandsWorkarounds
+ createMeshShardOpOperandsWorkarounds(mlir::tt::MeshShardType shardType);

// Create workarounds for concat op operands.
static TTNNOperandsWorkarounds
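Since the factory now receives the shard type, one plausible shape for the new overload is sketched below. This is an assumption about the design, not the actual body in TTNNWorkarounds.cpp; makeDefaultWorkarounds and withSystemMemoryLayout are hypothetical stand-ins.

// Sketch: identity shards are ignored at runtime (plain copy), so they can
// skip operand workarounds that only matter when data is actually resharded.
TTNNOperandsWorkarounds
TTNNOperandsWorkaroundsFactory::createMeshShardOpOperandsWorkarounds(
    mlir::tt::MeshShardType shardType) {
  TTNNOperandsWorkarounds workarounds = makeDefaultWorkarounds();
  if (shardType != mlir::tt::MeshShardType::Identity) {
    workarounds = withSystemMemoryLayout(workarounds); // hypothetical helper
  }
  return workarounds;
}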
2 changes: 1 addition & 1 deletion include/ttmlir/Target/TTNN/types.fbs
@@ -24,7 +24,7 @@ enum MeshShardDirection: uint32 {
}

enum MeshShardType: uint32 {
- Manual,
+ Identity,
Replicate,
Maximal,
Devices,