Change Manual shard_type to Identity and remove unnecessary mesh_shar… #2373

Open: wants to merge 1 commit into main
113 changes: 69 additions & 44 deletions include/ttmlir/Conversion/StableHLOToTTIR/ShardingUtils.h
@@ -12,6 +12,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/PatternMatch.h"
#include "shardy/dialect/sdy/ir/dialect.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"

@@ -28,53 +29,36 @@ class MeshSharding {
llvm::Expected<bool>
convertGSPMDShardingToMeshSharding(StringRef shardingStr);

// Check and update arg sharding attribute and determine if mesh_shard op
// needs to be created or not.
bool checkAndUpdateGSPMDArgSharding(mlir::PatternRewriter &rewriter,
mlir::stablehlo::CustomCallOp srcOp,
mlir::StringAttr shardingAttr);

// Check and update ret sharding attribute and determine if mesh_shard op
// needs to be created or not.
bool checkAndUpdateGSPMDRetSharding(mlir::PatternRewriter &rewriter,
mlir::stablehlo::CustomCallOp srcOp,
mlir::StringAttr shardingAttr);

// Convert sdy.sharding to meshSharding.
llvm::Expected<bool>
convertSdyShardingToMeshSharding(mlir::sdy::TensorShardingAttr sdySharding,
mlir::sdy::MeshAttr mesh,
mlir::tt::MeshShardDirection direction);

// Check and update function arg sharding
template <typename AttrType>
void checkAndUpdateFuncArgSharding(mlir::PatternRewriter &rewriter,
mlir::func::FuncOp funcOp, uint64_t argNum,
AttrType shardingAttr,
llvm::StringRef argShardingStrRef) {
if (auto argShardingAttr =
funcOp.getArgAttrOfType<AttrType>(argNum, argShardingStrRef)) {
if (argShardingAttr == shardingAttr) {
setDummyShardingOp();
rewriter.modifyOpInPlace(
funcOp, [&]() { funcOp.removeArgAttr(argNum, argShardingStrRef); });
} else {
llvm_unreachable(
"MeshSharding operation and function argument shardings "
"are different.");
}
}
}
// Check and update arg sharding attribute and determine if
// mesh_shard op needs to be created or not.
bool checkAndUpdateShardyArgSharding(
mlir::PatternRewriter &rewriter, mlir::func::FuncOp funcOp,
mlir::Value argOperand, mlir::sdy::TensorShardingAttr shardingAttr);

// Check and update function ret sharding
template <typename AttrType>
void checkAndUpdateFuncReturnSharding(mlir::PatternRewriter &rewriter,
mlir::func::FuncOp funcOp,
uint64_t retNum, AttrType shardingAttr,
llvm::StringRef retShardingStrRef) {
if (auto retShardingAttr =
funcOp.getResultAttrOfType<AttrType>(retNum, retShardingStrRef)) {
if (retShardingAttr == shardingAttr) {
setDummyShardingOp();
rewriter.modifyOpInPlace(funcOp, [&]() {
funcOp.removeResultAttr(
retNum,
mlir::StringAttr::get(rewriter.getContext(), retShardingStrRef));
});
} else {
llvm_unreachable("MeshSharding operation and function return shardings "
"are different.");
}
}
}
// Check and update ret sharding attribute and determine if mesh_shard op
// needs to be created or not.
bool
checkAndUpdateShardyRetSharding(mlir::PatternRewriter &rewriter,
mlir::func::FuncOp funcOp, uint64_t retIdx,
mlir::sdy::TensorShardingAttr shardingAttr);

// Getter functions.
mlir::tt::MeshShardDirection getShardDirection() const {
@@ -103,21 +87,62 @@ class MeshSharding {
meshShape = llvm::SmallVector<int64_t>{-1};
}

// Force dummy sharding op by setting shard_type to manual. The mesh_shard op
// will be ignored at runtime by simply copying input tensor to output.
void setDummyShardingOp() { shardType = mlir::tt::MeshShardType::Manual; }
// Force dummy sharding op by setting shard_type to identity. The mesh_shard
// op will be ignored at runtime by simply copying input tensor to output.
void setDummyShardingOp() { shardType = mlir::tt::MeshShardType::Identity; }

private:
mlir::tt::MeshShardDirection shardDirection =
mlir::tt::MeshShardDirection::ShardToFull;
mlir::tt::MeshShardType shardType = mlir::tt::MeshShardType::Manual;
mlir::tt::MeshShardType shardType = mlir::tt::MeshShardType::Identity;
llvm::SmallVector<int64_t> shardShape{-1};
llvm::SmallVector<int64_t> shardDims{-1};
llvm::SmallVector<int64_t> meshShape{-1};
llvm::SmallVector<int64_t> deviceIds{-1};
bool lastTileDimReplicate = false;
};

// Remove arg sharding and return true if it is found, otherwise return false.
template <typename AttrType>
bool checkAndRemoveFuncArgSharding(mlir::PatternRewriter &rewriter,
mlir::func::FuncOp funcOp, uint64_t argNum,
AttrType shardingAttr,
llvm::StringRef argShardingStrRef) {
if (auto argShardingAttr =
funcOp.getArgAttrOfType<AttrType>(argNum, argShardingStrRef)) {
if (argShardingAttr == shardingAttr) {
rewriter.modifyOpInPlace(
funcOp, [&]() { funcOp.removeArgAttr(argNum, argShardingStrRef); });
return true;
}
llvm_unreachable("MeshSharding operation and function argument shardings "
"are different.");
}
return false;
}

// Remove ret sharding and return true if it is found, otherwise return false.
template <typename AttrType>
bool checkAndRemoveFuncReturnSharding(mlir::PatternRewriter &rewriter,
mlir::func::FuncOp funcOp,
uint64_t retIdx, AttrType shardingAttr,
llvm::StringRef retShardingStrRef) {
if (auto retShardingAttr =
funcOp.getResultAttrOfType<AttrType>(retIdx, retShardingStrRef)) {
if (retShardingAttr == shardingAttr) {
rewriter.modifyOpInPlace(funcOp, [&]() {
funcOp.removeResultAttr(
retIdx,
mlir::StringAttr::get(rewriter.getContext(), retShardingStrRef));
});
return true;
}
llvm_unreachable("MeshSharding operation and function return shardings "
"are different.");
}
return false;
}

// Sharding related string definitions from open-xla
// https://github.com/openxla/xla/blob/main/xla/service/spmd/shardy/constants.h

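For context, a minimal caller sketch for the new free-function helpers above (hypothetical code, not part of this diff; the `mhlo.sharding` attribute name, the `handleArgSharding` wrapper, and the surrounding wiring are illustrative assumptions):

```cpp
// Sketch only: drop a matching GSPMD argument sharding and, when found,
// lower the corresponding mesh_shard op as an identity pass-through.
static void handleArgSharding(mlir::PatternRewriter &rewriter,
                              mlir::func::FuncOp funcOp, uint64_t argNum,
                              mlir::StringAttr shardingAttr,
                              MeshSharding &meshSharding) {
  if (checkAndRemoveFuncArgSharding<mlir::StringAttr>(
          rewriter, funcOp, argNum, shardingAttr,
          /*argShardingStrRef=*/"mhlo.sharding")) {
    // The argument already carried the same sharding, so the mesh_shard op
    // degenerates to a copy of the input tensor to the output.
    meshSharding.setDummyShardingOp();
  }
}
```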
4 changes: 2 additions & 2 deletions include/ttmlir/Dialect/TT/IR/TTOpsEnums.td
@@ -164,14 +164,14 @@ def TT_MeshShardDirection: I32EnumAttr<"MeshShardDirection", "TT MeshShardDirect
let cppNamespace = "::mlir::tt";
}

def TT_MeshShardType_Manual : I32EnumAttrCase<"Manual", 0, "manual">;
def TT_MeshShardType_Identity : I32EnumAttrCase<"Identity", 0, "identity">;
def TT_MeshShardType_Replicate : I32EnumAttrCase<"Replicate", 1, "replicate">;
def TT_MeshShardType_Maximal : I32EnumAttrCase<"Maximal", 2, "maximal">;
def TT_MeshShardType_Devices : I32EnumAttrCase<"Devices", 3, "devices">;

def TT_MeshShardType: I32EnumAttr<"MeshShardType", "TT MeshShardType",
[
TT_MeshShardType_Manual,
TT_MeshShardType_Identity,
TT_MeshShardType_Replicate,
TT_MeshShardType_Maximal,
TT_MeshShardType_Devices,
8 changes: 8 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -464,6 +464,14 @@ def TT_MeshShardDirectionAttr : EnumAttr<TT_Dialect, TT_MeshShardDirection, "sha
}

def TT_MeshShardTypeAttr : EnumAttr<TT_Dialect, TT_MeshShardType, "shard_type"> {
let summary = "MeshShard shard_type attribute in TT dialect";
let description = [{
Defines the sharding of tensor data for the mesh_shard op.
- Identity: input and output tensors are pre-sharded (same data); no sharding is required.
- Replicate: all devices have the full tensor (same data).
- Maximal: one device or a subset of the devices has the full tensor (same data).
- Devices: all or a subset of the devices have a sharded (partial) tensor (different data).
}];
let assemblyFormat = "`<` $value `>`";
}

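As a rough illustration of the four cases (a sketch under the assumption that the TableGen enum is generated as `mlir::tt::MeshShardType`; the actual lowering/runtime handling is not part of this diff):

```cpp
// Sketch only: how a consumer might dispatch on shard_type.
void describeShardType(mlir::tt::MeshShardType shardType) {
  switch (shardType) {
  case mlir::tt::MeshShardType::Identity:
    // Tensors are already pre-sharded: simply forward the input to the output.
    break;
  case mlir::tt::MeshShardType::Replicate:
    // Every device receives the full tensor.
    break;
  case mlir::tt::MeshShardType::Maximal:
    // A single device (or a subset of devices) holds the full tensor.
    break;
  case mlir::tt::MeshShardType::Devices:
    // Devices hold different shards, as described by shard_shape/shard_dims.
    break;
  }
}
```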
3 changes: 2 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNOps.td
@@ -1503,7 +1503,8 @@ def TTNN_MeshShardOp: TTNN_Op<"mesh_shard"> {

let extraClassDeclaration = [{
wa::TTNNOperandsWorkarounds getOperandsWorkarounds() {
return wa::TTNNOperandsWorkaroundsFactory::createMeshShardOpOperandsWorkarounds();
mlir::tt::ttnn::MeshShardOp op = mlir::cast<mlir::tt::ttnn::MeshShardOp>(this->getOperation());
return wa::TTNNOperandsWorkaroundsFactory::createMeshShardOpOperandsWorkarounds(op.getShardType());
}
}];

3 changes: 2 additions & 1 deletion include/ttmlir/Dialect/TTNN/IR/TTNNWorkarounds.h
@@ -235,7 +235,8 @@ class TTNNOperandsWorkaroundsFactory {
createFullOpOperandsWorkarounds(RankedTensorType outputType);

// Create workarounds for mesh shard op operands.
static TTNNOperandsWorkarounds createMeshShardOpOperandsWorkarounds();
static TTNNOperandsWorkarounds
createMeshShardOpOperandsWorkarounds(mlir::tt::MeshShardType shardType);

// Create workarounds for concat op operands.
static TTNNOperandsWorkarounds
2 changes: 1 addition & 1 deletion include/ttmlir/Target/TTNN/types.fbs
@@ -24,7 +24,7 @@ enum MeshShardType: uint32 {
}

enum MeshShardType: uint32 {
Manual,
Identity,
Replicate,
Maximal,
Devices,