Skip to content

Commit

Permalink
fix auto dist
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 committed Jan 16, 2025
1 parent 1ca0154 commit 42e6e46
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,12 @@ private Dictionary<IRType, List<Expr>> VisitLeafArgument(ParameterKind parameter
case (ParameterKind.Attribute, TensorConst e):
updateBuckets(buckets, new[] { e.With() }); // remove all old users.
break;
case (ParameterKind.Attribute, ShapeConst e):
updateBuckets(buckets, new[] { e.With() }); // remove all old users.
break;
case (ParameterKind.Attribute, DimensionConst e):
updateBuckets(buckets, new[] { e.With() }); // remove all old users.
break;
case (ParameterKind.Attribute, None e):
updateBuckets(buckets, new[] { e.With() });
break;
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Evaluator/Tensors/Slice.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ public IValue Visit(IEvaluateContext context, Slice sl)
}
else if (inputValue is TensorValue inputTValue)
{
var input = inputTValue.AsTensor().Cast<long>().ToOrtTensor();
var input = inputTValue.AsTensor().ToOrtTensor();
var begins = beginsTensor.Cast<long>().ToOrtTensor();
var ends = endsTensor.Cast<long>().ToOrtTensor();
var axes = axesTensor.Cast<long>().ToOrtTensor();
Expand Down
16 changes: 16 additions & 0 deletions src/Nncase.Tests/Core/UnitTestTypeInfer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,22 @@ public void TestReshapeInfer()
Assert.Equal(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimVar, 896 }), resultb);
}

[Fact]
public void TestConcatInfer()
{
var seq_len = new Var("seq_len", new TensorType(DataTypes.Int64, Shape.Scalar));
var hist_len = new Var("his_len", new TensorType(DataTypes.Int64, Shape.Scalar));
var lhs = new Var(new TensorType(DataTypes.Float32, new Dimension[] { 1, seq_len, 2, 64 }));
var rhs = new Var(new TensorType(DataTypes.Float32, new Dimension[] { 1, hist_len, 2, 64 }));
var reshape = Concat(new IR.Tuple(new[] { lhs, rhs }), 1);
var result = reshape.CheckedType;
Assert.IsType<TensorType>(result);
Assert.IsType<Call>(((TensorType)result).Shape[1].Value);
var call = (Call)((TensorType)result).Shape[1].Value;
Assert.Equal(seq_len, call.Arguments[0]);
Assert.Equal(hist_len, call.Arguments[1]);
}

private void CheckInferShape(Expr expr, params Dimension[] shapeDimensions)
{
CheckInferShape(expr, new Shape(shapeDimensions));
Expand Down

0 comments on commit 42e6e46

Please sign in to comment.