Skip to content

Commit

Permalink
fix distributed graph partition
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 committed Jan 17, 2025
1 parent 497e547 commit 5147404
Showing 1 changed file with 8 additions and 16 deletions.
24 changes: 8 additions & 16 deletions modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,27 +64,19 @@ bool CheckField(Expr f)
}

bool isSupport = false;
switch (arg.Edge.Target.Expr)
switch (arg.Edge.Source.Expr, arg.Edge.Target.Expr)
{
case Call call:
if (call.Target is IR.CPU.Boxing { NewType: TensorType } && call.Arguments[0].CheckedType is DistributedType)
case (Call callee, Call caller):
switch (callee.CheckedType, caller.CheckedType)
{
isSupport = true;
}
else if (call.Target is IR.CPU.Boxing { NewType: DistributedType })
{
if (arg.Edge.Source.Expr.CheckedType is not TensorType)
{
case (DistributedType, TensorType) when caller.Target is IR.CPU.Boxing:
case (DistributedType, DistributedType):
isSupport = true;
}
}
else if (call.CheckedType is DistributedType)
{
isSupport = true;
break;
}

break;
case IR.Tuple tp:
case (Call field, IR.Tuple tp):
isSupport = tp.Fields.AsValueEnumerable().All(f => f is Call c && CheckField(c)) ? true : false;
break;
default:
Expand Down Expand Up @@ -141,7 +133,7 @@ protected override Expr OnComplexCluster(ClusteredBidirectionalGraph<ExprVertex,
var argumentDict = new Dictionary<Var, Expr>(ReferenceEqualityComparer.Instance);
foreach (var (pre, post) in pairs)
{
if (pre is Const)
if (pre is not (Call or Var))
{
continue;
}
Expand Down

0 comments on commit 5147404

Please sign in to comment.