From 42e6e469d218ee23bf6e145bef291b6c252c5bee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=83=91=E5=90=AF=E8=88=AA?= <597323109@qq.com> Date: Thu, 16 Jan 2025 16:19:30 +0800 Subject: [PATCH] fix auto dist --- .../Passes/Distributed/AutoDistributed.cs | 6 ++++++ src/Nncase.Evaluator/Tensors/Slice.cs | 2 +- src/Nncase.Tests/Core/UnitTestTypeInfer.cs | 16 ++++++++++++++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs index dbc8de7672..8d7919a8f6 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs @@ -496,6 +496,12 @@ private Dictionary> VisitLeafArgument(ParameterKind parameter case (ParameterKind.Attribute, TensorConst e): updateBuckets(buckets, new[] { e.With() }); // remove all old users. break; + case (ParameterKind.Attribute, ShapeConst e): + updateBuckets(buckets, new[] { e.With() }); // remove all old users. + break; + case (ParameterKind.Attribute, DimensionConst e): + updateBuckets(buckets, new[] { e.With() }); // remove all old users. + break; case (ParameterKind.Attribute, None e): updateBuckets(buckets, new[] { e.With() }); break; diff --git a/src/Nncase.Evaluator/Tensors/Slice.cs b/src/Nncase.Evaluator/Tensors/Slice.cs index b1763790a3..93b58c7eac 100644 --- a/src/Nncase.Evaluator/Tensors/Slice.cs +++ b/src/Nncase.Evaluator/Tensors/Slice.cs @@ -64,7 +64,7 @@ public IValue Visit(IEvaluateContext context, Slice sl) } else if (inputValue is TensorValue inputTValue) { - var input = inputTValue.AsTensor().Cast().ToOrtTensor(); + var input = inputTValue.AsTensor().ToOrtTensor(); var begins = beginsTensor.Cast().ToOrtTensor(); var ends = endsTensor.Cast().ToOrtTensor(); var axes = axesTensor.Cast().ToOrtTensor(); diff --git a/src/Nncase.Tests/Core/UnitTestTypeInfer.cs b/src/Nncase.Tests/Core/UnitTestTypeInfer.cs index 0e290ae13c..d04efd5857 100644 --- a/src/Nncase.Tests/Core/UnitTestTypeInfer.cs +++ b/src/Nncase.Tests/Core/UnitTestTypeInfer.cs @@ -343,6 +343,22 @@ public void TestReshapeInfer() Assert.Equal(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimVar, 896 }), resultb); } + [Fact] + public void TestConcatInfer() + { + var seq_len = new Var("seq_len", new TensorType(DataTypes.Int64, Shape.Scalar)); + var hist_len = new Var("his_len", new TensorType(DataTypes.Int64, Shape.Scalar)); + var lhs = new Var(new TensorType(DataTypes.Float32, new Dimension[] { 1, seq_len, 2, 64 })); + var rhs = new Var(new TensorType(DataTypes.Float32, new Dimension[] { 1, hist_len, 2, 64 })); + var reshape = Concat(new IR.Tuple(new[] { lhs, rhs }), 1); + var result = reshape.CheckedType; + Assert.IsType(result); + Assert.IsType(((TensorType)result).Shape[1].Value); + var call = (Call)((TensorType)result).Shape[1].Value; + Assert.Equal(seq_len, call.Arguments[0]); + Assert.Equal(hist_len, call.Arguments[1]); + } + private void CheckInferShape(Expr expr, params Dimension[] shapeDimensions) { CheckInferShape(expr, new Shape(shapeDimensions));