diff --git a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs index 8d7919a8f6..dbc8de7672 100644 --- a/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs +++ b/modules/Nncase.Modules.CPU/Passes/Distributed/AutoDistributed.cs @@ -496,12 +496,6 @@ private Dictionary> VisitLeafArgument(ParameterKind parameter case (ParameterKind.Attribute, TensorConst e): updateBuckets(buckets, new[] { e.With() }); // remove all old users. break; - case (ParameterKind.Attribute, ShapeConst e): - updateBuckets(buckets, new[] { e.With() }); // remove all old users. - break; - case (ParameterKind.Attribute, DimensionConst e): - updateBuckets(buckets, new[] { e.With() }); // remove all old users. - break; case (ParameterKind.Attribute, None e): updateBuckets(buckets, new[] { e.With() }); break; diff --git a/src/Nncase.Compiler/Compiler.cs b/src/Nncase.Compiler/Compiler.cs index 2f1f435055..02750c92b2 100644 --- a/src/Nncase.Compiler/Compiler.cs +++ b/src/Nncase.Compiler/Compiler.cs @@ -144,6 +144,7 @@ public void TargetIndependentPass(IPassManager passManager) p.Add(); p.Add(); p.Add(); + p.Add(); p.Add(); p.Add(); p.Add(); diff --git a/src/Nncase.Core/CostModel/Cost.cs b/src/Nncase.Core/CostModel/Cost.cs index c45c37395b..a0d9842299 100644 --- a/src/Nncase.Core/CostModel/Cost.cs +++ b/src/Nncase.Core/CostModel/Cost.cs @@ -202,7 +202,7 @@ public static UInt128 GetMemoryAccess(IRType type) { return type switch { - TensorType t => (UInt128)(t.Shape.Aggregate(1D, (acc, x) => acc * (x.IsFixed ? x.FixedValue : 1)) * t.DType.SizeInBytes), + TensorType t => (UInt128)t.Shape.ProdWithDynamicAsOne() * (UInt128)t.DType.SizeInBytes, TupleType t => t.Fields.Sum(GetMemoryAccess), DistributedType t => GetMemoryAccess(Utilities.DistributedUtility.GetDividedTensorType(t)), _ => 0, @@ -218,7 +218,7 @@ public static UInt128 GetFakeMemoryAccess(IRType type, uint bits) { return type switch { - TensorType t => (UInt128)Math.Ceiling((float)t.Shape.Aggregate(1D, (acc, x) => acc * (x.IsFixed ? x.FixedValue : 1)) * t.DType.SizeInBytes * bits / 8), + TensorType t => (UInt128)Math.Ceiling((float)t.Shape.ProdWithDynamicAsOne() * t.DType.SizeInBytes * bits / 8), TupleType t => t.Fields.Sum(x => GetFakeMemoryAccess(x, bits)), _ => 0, }; @@ -228,7 +228,7 @@ public static UInt128 GetCPUCycles(IRType type, double cyclesPerElement = 1) { return type switch { - TensorType t => (UInt128)(t.Shape.Aggregate(1D, (acc, x) => acc * (x.IsFixed ? x.FixedValue : 1)) * cyclesPerElement), + TensorType t => (UInt128)(t.Shape.ProdWithDynamicAsOne() * cyclesPerElement), TupleType t => t.Fields.Sum(GetMemoryAccess), DistributedType t => GetCPUCycles(Utilities.DistributedUtility.GetDividedTensorType(t)), _ => 0, diff --git a/src/Nncase.Core/Evaluator/Metric.cs b/src/Nncase.Core/Evaluator/Metric.cs index 052ea5443f..e393e7914d 100644 --- a/src/Nncase.Core/Evaluator/Metric.cs +++ b/src/Nncase.Core/Evaluator/Metric.cs @@ -189,7 +189,7 @@ public static UInt128 GetFLOPs(IRType type, long scale = 1) { return type switch { - TensorType t => (UInt128)t.Shape.Aggregate(scale, (acc, x) => acc * (x.IsFixed ? x.FixedValue : 1)), + TensorType t => (UInt128)t.Shape.ProdWithDynamicAsOne(scale), TupleType t => t.Fields.Sum(f => GetFLOPs(f, scale)), _ => 0, }; diff --git a/src/Nncase.Core/IR/Const.cs b/src/Nncase.Core/IR/Const.cs index 071631ea1f..fc83908cc8 100644 --- a/src/Nncase.Core/IR/Const.cs +++ b/src/Nncase.Core/IR/Const.cs @@ -156,10 +156,6 @@ public static Const FromValue(IValue value) : new TensorConst(tv.AsTensor()); case TupleValue tpv: return new TupleConst(tpv); - case ShapeValue sv: - return new ShapeConst(sv.Dimensions.ToArray()); - case DimensionValue dv: - return new DimensionConst(dv.Dimension); default: throw new ArgumentOutOfRangeException(nameof(value)); } diff --git a/src/Nncase.Core/IR/Dimension.cs b/src/Nncase.Core/IR/Dimension.cs index d818fab574..a41e4b0232 100644 --- a/src/Nncase.Core/IR/Dimension.cs +++ b/src/Nncase.Core/IR/Dimension.cs @@ -9,237 +9,233 @@ using Nncase.Passes; using Nncase.Passes.Mutators; -namespace Nncase.IR +namespace Nncase.IR; + +/// +/// Dimension kind. +/// +public enum DimensionKind : byte { /// - /// Dimension kind. + /// Dynamic dimension. /// - public enum DimensionKind : byte - { - /// - /// Unknown dimension. - /// - Unknown, - - /// - /// Fixed dimesnion. - /// - Fixed, - - /// - /// Used for shape pattern. - /// - Any, - } + Dynamic, /// - /// Shape dimension. + /// Fixed dimesnion. /// - public sealed class Dimension : IEquatable - { - public static readonly Dimension Any = new Dimension(); + Fixed, + + /// + /// Used for shape pattern. + /// + Unknown, +} + +/// +/// Shape dimension. +/// +public struct Dimension : IEquatable +{ + public static readonly Dimension Unknown = new Dimension(None.Default); - private readonly long? _fixedValue; - private readonly Expr? _exprValue; + private readonly long? _fixedValue; + private readonly Expr? _exprValue; - /// - /// Initializes a new instance of the class. - /// - /// Dimension value. - public Dimension(long value) + /// + /// Initializes a new instance of the struct. + /// + /// Dimension value. + public Dimension(long value) + { + Kind = DimensionKind.Fixed; + _fixedValue = value; + } + + public Dimension(Expr value) + { + value = CompilerServices.FastSimplifyForDimension(value); + if (value is TensorConst tc) { Kind = DimensionKind.Fixed; - _fixedValue = value; + _fixedValue = tc.Value.ToScalar(); } - - public Dimension(Expr value) + else if (value is None) { - value = CompilerServices.FastSimplifyForDimension(value); - if (value is TensorConst tc) - { - Kind = DimensionKind.Fixed; - _fixedValue = tc.Value.ToScalar(); - } - else - { - Kind = DimensionKind.Unknown; - _exprValue = value; - } + Kind = DimensionKind.Unknown; + _exprValue = None.Default; } - - private Dimension() + else { - Kind = DimensionKind.Any; - _exprValue = new Var("Any", DataTypes.Int64); + Kind = DimensionKind.Dynamic; + _exprValue = value; } + } - /// - /// Gets kind. - /// - public DimensionKind Kind { get; } + /// + /// Gets kind. + /// + public DimensionKind Kind { get; } - /// - /// Gets value. - /// - public Expr Value => _exprValue ?? _fixedValue!.Value; + /// + /// Gets value. + /// + public Expr Value => _exprValue ?? _fixedValue!.Value; - /// - /// Gets FixedValue. - /// - public long FixedValue - { - get => _fixedValue ?? - throw new InvalidOperationException("Only Can Get It When Shape Is Fixed !"); - } + /// + /// Gets FixedValue. + /// + public long FixedValue + { + get => _fixedValue ?? + throw new InvalidOperationException("Only Can Get It When Shape Is Fixed !"); + } - /// - /// Gets a value indicating whether unknown. - /// - public bool IsUnknown => Kind is DimensionKind.Unknown or DimensionKind.Any; - - /// - /// Gets a value indicating whether fixed. - /// - public bool IsFixed => Kind == DimensionKind.Fixed; - - public bool IsAny => Kind == DimensionKind.Any; - - /// - /// Convert to a fixed . - /// - /// Dimension value. - public static implicit operator Dimension(long value) => new(value); - - /// - /// Convert to a expression. - /// - /// Dimension value. - public static implicit operator Dimension(Expr value) => value switch - { - DimensionConst dc => dc.Value, - _ => new(value), - }; + /// + /// Gets a value indicating whether dynamic. + /// + public bool IsDynamic => Kind is DimensionKind.Dynamic; - public static bool operator ==(Dimension left, Dimension right) - { - return left.Equals(right); - } + /// + /// Gets a value indicating whether fixed. + /// + public bool IsFixed => Kind == DimensionKind.Fixed; - public static bool operator !=(Dimension left, Dimension right) - { - return !(left == right); - } + public bool IsUnknown => Kind == DimensionKind.Unknown; - public static Dimension operator +(Dimension lhs, Dimension rhs) => (lhs.IsFixed, rhs.IsFixed) switch - { - (true, true) => lhs.FixedValue + rhs.FixedValue, - (true, _) when lhs.FixedValue == 0 => rhs, - (_, true) when rhs.FixedValue == 0 => lhs, - (_, _) => new Dimension(lhs.Value + rhs.Value), - }; + /// + /// Convert to a fixed . + /// + /// Dimension value. + public static implicit operator Dimension(long value) => new(value); - public static Dimension operator +(Dimension lhs, int rhs) => lhs.IsFixed ? lhs.FixedValue + rhs : new Dimension(lhs.Value + rhs); + /// + /// Convert to a expression. + /// + /// Dimension value. + public static implicit operator Dimension(Expr value) => value switch + { + TensorConst dc => new(dc.Value.ToScalar()), + _ => new(value), + }; - public static Dimension operator -(Dimension lhs, Dimension rhs) => (lhs.IsFixed, rhs.IsFixed) switch - { - (true, true) => lhs.FixedValue - rhs.FixedValue, - (_, true) when rhs.FixedValue == 0 => lhs, - (_, _) => new Dimension(lhs.Value - rhs.Value), - }; + public static bool operator ==(Dimension left, Dimension right) + { + return left.Equals(right); + } - public static Dimension operator *(Dimension lhs, Dimension rhs) => (lhs.IsFixed, rhs.IsFixed) switch - { - (true, true) => lhs.FixedValue * rhs.FixedValue, - (true, _) when lhs.FixedValue == 1 => rhs, - (_, true) when rhs.FixedValue == 1 => lhs, - (_, _) => new Dimension(lhs.Value * rhs.Value), - }; + public static bool operator !=(Dimension left, Dimension right) + { + return !(left == right); + } - public static Dimension operator /(Dimension lhs, Dimension rhs) => (lhs.IsFixed, rhs.IsFixed) switch - { - (true, true) => lhs.FixedValue / rhs.FixedValue, - (_, _) => new Dimension(lhs.Value / rhs.Value), - }; + public static Dimension operator +(Dimension lhs, Dimension rhs) => (lhs.IsFixed, rhs.IsFixed) switch + { + (true, true) => lhs.FixedValue + rhs.FixedValue, + (true, _) when lhs.FixedValue == 0 => rhs, + (_, true) when rhs.FixedValue == 0 => lhs, + (_, _) => new Dimension(lhs.Value + rhs.Value), + }; - public static Dimension Abs(Dimension value) - { - if (value.IsFixed) - { - return System.Math.Abs(value.FixedValue); - } + public static Dimension operator +(Dimension lhs, int rhs) => lhs.IsFixed ? lhs.FixedValue + rhs : new Dimension(lhs.Value + rhs); - return value.Value.Metadata.Range?.Min >= 0 ? value.Value : IR.F.Math.Abs(value.Value); - } + public static Dimension operator -(Dimension lhs, Dimension rhs) => (lhs.IsFixed, rhs.IsFixed) switch + { + (true, true) => lhs.FixedValue - rhs.FixedValue, + (_, true) when rhs.FixedValue == 0 => lhs, + (_, _) => new Dimension(lhs.Value - rhs.Value), + }; - public static Dimension Clamp(Dimension value, Dimension min, Dimension max) - { - if (value.IsFixed && min.IsFixed && max.IsFixed) - { - return System.Math.Clamp(value.FixedValue, min.FixedValue, max.FixedValue); - } + public static Dimension operator *(Dimension lhs, Dimension rhs) => (lhs.IsFixed, rhs.IsFixed) switch + { + (true, true) => lhs.FixedValue * rhs.FixedValue, + (true, _) when lhs.FixedValue == 1 => rhs, + (_, true) when rhs.FixedValue == 1 => lhs, + (_, _) => new Dimension(lhs.Value * rhs.Value), + }; - return IR.F.Math.Clamp(value.Value, min.Value, max.Value); - } + public static Dimension operator /(Dimension lhs, Dimension rhs) => (lhs.IsFixed, rhs.IsFixed) switch + { + (true, true) => lhs.FixedValue / rhs.FixedValue, + (_, _) => new Dimension(lhs.Value / rhs.Value), + }; - public static Dimension CeilDiv(Dimension lhs, Dimension rhs) + public static Dimension Abs(Dimension value) + { + if (value.IsFixed) { - if (lhs.IsFixed && rhs.IsFixed) - { - return (lhs.FixedValue + rhs.FixedValue - 1) / rhs.FixedValue; - } - - return IR.F.Math.CeilDiv(lhs.Value, rhs.Value); + return System.Math.Abs(value.FixedValue); } - // public static Dimension Unknown(string? name = null) => new Dimension(name is null ? new Var(DataTypes.Int64) : new Var(name, DataTypes.Int64)); - - public static Dimension Unknown(string? name = null) => Any; + return value.Value.Metadata.Range?.Min >= 0 ? value.Value : IR.F.Math.Abs(value.Value); + } - /// - public override string ToString() + public static Dimension Clamp(Dimension value, Dimension min, Dimension max) + { + if (value.IsFixed && min.IsFixed && max.IsFixed) { - return Value?.ToString() ?? "?"; + return System.Math.Clamp(value.FixedValue, min.FixedValue, max.FixedValue); } - /// - public override bool Equals(object? obj) - { - return obj is Dimension dimension && Equals(dimension); - } + return IR.F.Math.Clamp(value.Value, min.Value, max.Value); + } - /// - public bool Equals(Dimension? other) + public static Dimension CeilDiv(Dimension lhs, Dimension rhs) + { + if (lhs.IsFixed && rhs.IsFixed) { - return other is not null && (Kind, other.Kind) switch - { - (DimensionKind.Any, DimensionKind.Any) => true, - (DimensionKind.Unknown, DimensionKind.Unknown) => Value == other.Value, - (DimensionKind.Fixed, DimensionKind.Fixed) => FixedValue == other.FixedValue, - (_, _) => false, - }; + return (lhs.FixedValue + rhs.FixedValue - 1) / rhs.FixedValue; } - /// - public override int GetHashCode() - { - return IsFixed ? HashCode.Combine(Kind, FixedValue) : HashCode.Combine(Kind, Value); - } + return IR.F.Math.CeilDiv(lhs.Value, rhs.Value); + } - public bool HasFixedValue(Predicate predicate) - { - return IsFixed && predicate(FixedValue); - } + /// + public override string ToString() => Kind switch + { + DimensionKind.Dynamic when Value is Var var => $"%{var.Name}", + DimensionKind.Fixed => FixedValue.ToString(), + DimensionKind.Unknown => "?", + _ => "...", + }; + + /// + public override bool Equals(object? obj) + { + return obj is Dimension dimension && Equals(dimension); + } - public bool IsAssignableFrom(Dimension dimension) + /// + public bool Equals(Dimension? other) + { + return other is not null && (Kind, other.Value.Kind) switch { - if (IsUnknown) - { - return true; - } + (DimensionKind.Dynamic, DimensionKind.Dynamic) => Value == other.Value.Value, + (DimensionKind.Fixed, DimensionKind.Fixed) => FixedValue == other.Value.FixedValue, + (DimensionKind.Unknown, DimensionKind.Unknown) => true, + (_, _) => false, + }; + } - return dimension.Kind == DimensionKind.Fixed && Value == dimension.Value; - } + /// + public override int GetHashCode() + { + return IsFixed ? HashCode.Combine(Kind, FixedValue) : HashCode.Combine(Kind, Value); + } - public Expr ToExpr() => IsFixed ? FixedValue : Value; + public bool HasFixedValue(Predicate predicate) + { + return IsFixed && predicate(FixedValue); } + + public bool IsAssignableFrom(Dimension dimension) => + (Kind, dimension.Kind) switch + { + (DimensionKind.Dynamic, DimensionKind.Dynamic) => Value == dimension.Value, + (DimensionKind.Fixed, DimensionKind.Fixed) => FixedValue == dimension.FixedValue, + (DimensionKind.Unknown, _) => true, + (_, _) => false, + }; + + public Expr ToExpr() => IsFixed ? FixedValue : Value; } diff --git a/src/Nncase.Core/IR/Expr.Conversion.cs b/src/Nncase.Core/IR/Expr.Conversion.cs index f32270e40e..3455cbba79 100644 --- a/src/Nncase.Core/IR/Expr.Conversion.cs +++ b/src/Nncase.Core/IR/Expr.Conversion.cs @@ -104,14 +104,6 @@ public abstract partial class Expr /// Value. public static implicit operator Expr(bool value) => (Const)value; - /// - /// Create from a . - /// - /// Shape. - public static implicit operator Expr(Shape shape) => - shape.IsFixed ? Const.FromShape(shape) - : IR.F.Tensors.Stack(new IR.Tuple(shape.Select(x => x.ToExpr()).ToArray()), 0); - /// /// Create from an array of. /// diff --git a/src/Nncase.Core/IR/Expr.Operators.cs b/src/Nncase.Core/IR/Expr.Operators.cs index 14409ce2eb..195b07da08 100644 --- a/src/Nncase.Core/IR/Expr.Operators.cs +++ b/src/Nncase.Core/IR/Expr.Operators.cs @@ -31,8 +31,8 @@ public partial class Expr { TensorConst tc => Tensor.FromScalar(tc.Value.ElementType, tc.Value[indices]), TupleConst tc => tc.Value[(int)indices.Single()].AsTensor(), - ShapeConst sc => new DimensionConst(sc.Value[(int)indices.Single()]), - IR.Tuple t => t.Fields[(int)indices.Single()], + Shape shape => shape.Dimensions[(int)indices.Single()], + IR.Tuple t => t[(int)indices.Single()], Call { Target: Concat { Axis: 0 } } c when indices.Length == 1 => c[Concat.Input][indices[0]][0], Call { Target: Reshape } c when c[Reshape.Shape] is TensorConst tc && tc.Value.Length == 1 && tc.Value.ToScalar() == 1 => c[Reshape.Input], _ => this[indices.Select(x => (Expr)x).ToArray()], diff --git a/src/Nncase.Core/IR/Expr.cs b/src/Nncase.Core/IR/Expr.cs index b32b4a0a62..06189c2491 100644 --- a/src/Nncase.Core/IR/Expr.cs +++ b/src/Nncase.Core/IR/Expr.cs @@ -45,6 +45,7 @@ internal Expr(IEnumerable operands) _operands = operands.ToArray(); foreach (var operand in _operands) { + ValidateOperand(operand); operand.AddUser(this); } @@ -57,6 +58,7 @@ internal Expr(Expr[] operands) _operands = operands; foreach (var operand in _operands) { + ValidateOperand(operand); operand.AddUser(this); } @@ -285,6 +287,14 @@ protected virtual int GetHashCodeCore() return HashCode.Combine(GetType(), HashCode.Combine(Operands)); } + protected virtual void OnOperandsReplaced() + { + InvalidateTypeInference(); + InvalidateHashCodeCache(); + InvalidateRange(); + RefreshDepth(); + } + private bool IsDescendantOf(Expr other, Dictionary visited) { if (visited.TryGetValue(this, out var result)) @@ -319,12 +329,12 @@ private bool IsDescendantOf(Expr other) return IsDescendantOf(other, new Dictionary(ReferenceEqualityComparer.Instance)); } - private void OnOperandsReplaced() + private void ValidateOperand(Expr operand) { - InvalidateTypeInference(); - InvalidateHashCodeCache(); - InvalidateRange(); - RefreshDepth(); + if (operand is Shape) + { + throw new ArgumentException($"{operand.GetType()} can't be an operand."); + } } private void InvalidateTypeInference() diff --git a/src/Nncase.Core/IR/ExprCloner.g.cs b/src/Nncase.Core/IR/ExprCloner.g.cs index 902c3f4183..e92c84aed4 100644 --- a/src/Nncase.Core/IR/ExprCloner.g.cs +++ b/src/Nncase.Core/IR/ExprCloner.g.cs @@ -98,21 +98,15 @@ protected override Expr VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, T } /// - protected override Expr VisitLeafTensorConst(TensorConst expr, TContext context) - { - return expr.With( - ); - } - - /// - protected override Expr VisitLeafShapeConst(ShapeConst expr, TContext context) + protected override Expr VisitLeafShape(IR.Shape expr, TContext context) { return expr.With( + dimensions: CloneArray(expr.Dimensions, context) ); } /// - protected override Expr VisitLeafDimensionConst(DimensionConst expr, TContext context) + protected override Expr VisitLeafTensorConst(TensorConst expr, TContext context) { return expr.With( ); diff --git a/src/Nncase.Core/IR/ExprFunctor.g.cs b/src/Nncase.Core/IR/ExprFunctor.g.cs index b26a670a3b..0097357d83 100644 --- a/src/Nncase.Core/IR/ExprFunctor.g.cs +++ b/src/Nncase.Core/IR/ExprFunctor.g.cs @@ -64,19 +64,14 @@ public partial class ExprFunctor internal protected virtual TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, TContext context) => VisitBaseFunction(expr, context); /// - /// Visit . - /// - internal protected virtual TExprResult VisitTensorConst(TensorConst expr, TContext context) => VisitConst(expr, context); - - /// - /// Visit . + /// Visit . /// - internal protected virtual TExprResult VisitShapeConst(ShapeConst expr, TContext context) => VisitConst(expr, context); + internal protected virtual TExprResult VisitShape(IR.Shape expr, TContext context) => DefaultVisit(expr, context); /// - /// Visit . + /// Visit . /// - internal protected virtual TExprResult VisitDimensionConst(DimensionConst expr, TContext context) => VisitConst(expr, context); + internal protected virtual TExprResult VisitTensorConst(TensorConst expr, TContext context) => VisitConst(expr, context); /// /// Visit . @@ -308,26 +303,19 @@ public partial class ExprFunctor /// internal protected sealed override TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitPrimFunctionWrapper(expr); /// - /// Visit . + /// Visit . /// - internal protected virtual TExprResult VisitTensorConst(TensorConst expr) => base.VisitTensorConst(expr, default); + internal protected virtual TExprResult VisitShape(IR.Shape expr) => base.VisitShape(expr, default); /// - internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr); - /// - /// Visit . - /// - internal protected virtual TExprResult VisitShapeConst(ShapeConst expr) => base.VisitShapeConst(expr, default); - - /// - internal protected sealed override TExprResult VisitShapeConst(ShapeConst expr, Unit context) => VisitShapeConst(expr); + internal protected sealed override TExprResult VisitShape(IR.Shape expr, Unit context) => VisitShape(expr); /// - /// Visit . + /// Visit . /// - internal protected virtual TExprResult VisitDimensionConst(DimensionConst expr) => base.VisitDimensionConst(expr, default); + internal protected virtual TExprResult VisitTensorConst(TensorConst expr) => base.VisitTensorConst(expr, default); /// - internal protected sealed override TExprResult VisitDimensionConst(DimensionConst expr, Unit context) => VisitDimensionConst(expr); + internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr); /// /// Visit . /// diff --git a/src/Nncase.Core/IR/ExprRewriter.cs b/src/Nncase.Core/IR/ExprRewriter.cs index ebef1b92e2..9ee7551908 100644 --- a/src/Nncase.Core/IR/ExprRewriter.cs +++ b/src/Nncase.Core/IR/ExprRewriter.cs @@ -22,8 +22,9 @@ public abstract partial class ExprRewriter : ExprVisitor class. /// /// Vist other functions. - public ExprRewriter(bool visitOtherFunctions = false) - : base(visitOtherFunctions) + /// Visit attributes. + public ExprRewriter(bool visitOtherFunctions = false, bool visitAttributes = false) + : base(visitOtherFunctions, visitAttributes) { } @@ -46,6 +47,18 @@ public Expr Rewrite(Expr expr, TContext context) return newExpr; } + public override IRType VisitTypeLeaf(AnyType type, TContext context) => type; + + public override IRType VisitTypeLeaf(CallableType type, TContext context) => type; + + public override IRType VisitTypeLeaf(InvalidType type, TContext context) => type; + + public override IRType VisitTypeLeaf(TensorType type, TContext context) => type; + + public override IRType VisitTypeLeaf(TupleType type, TContext context) => type; + + public override IRType VisitTypeLeaf(DistributedType type, TContext context) => type; + /// /// Default rewrite leaf routine. /// @@ -68,6 +81,20 @@ protected override void VisitOperands(Expr expr, TContext context) } } + protected override void VisitAttributes(Expr expr, TContext context) + { + var type = expr.RawCheckedType; + if (type != null) + { + var newType = VisitType(type, context); + if (!ReferenceEquals(type, newType)) + { + expr.CheckedType = newType; + SetMutated(); + } + } + } + private void DCE(Expr root, ExprScope exprScope) { // using var exprPin = new ExprPinner(root); @@ -84,8 +111,9 @@ public abstract partial class ExprRewriter : ExprRewriter /// Initializes a new instance of the class. /// /// Vist other functions. - protected ExprRewriter(bool visitOtherFunctions = false) - : base(visitOtherFunctions) + /// Visit attributes. + protected ExprRewriter(bool visitOtherFunctions = false, bool visitAttributes = false) + : base(visitOtherFunctions, visitAttributes) { } diff --git a/src/Nncase.Core/IR/ExprRewriter.g.cs b/src/Nncase.Core/IR/ExprRewriter.g.cs index a41f4a5d1b..197d4f8e7c 100644 --- a/src/Nncase.Core/IR/ExprRewriter.g.cs +++ b/src/Nncase.Core/IR/ExprRewriter.g.cs @@ -74,21 +74,15 @@ protected sealed override Expr VisitLeafPrimFunctionWrapper(PrimFunctionWrapper } /// - protected sealed override Expr VisitLeafTensorConst(TensorConst expr, TContext context) + protected sealed override Expr VisitLeafShape(IR.Shape expr, TContext context) { - return RewriteLeafTensorConst(expr, context); + return RewriteLeafShape(expr, context); } /// - protected sealed override Expr VisitLeafShapeConst(ShapeConst expr, TContext context) - { - return RewriteLeafShapeConst(expr, context); - } - - /// - protected sealed override Expr VisitLeafDimensionConst(DimensionConst expr, TContext context) + protected sealed override Expr VisitLeafTensorConst(TensorConst expr, TContext context) { - return RewriteLeafDimensionConst(expr, context); + return RewriteLeafTensorConst(expr, context); } /// @@ -328,19 +322,14 @@ protected sealed override Expr VisitLeafBufferOf(Buffers.BufferOf expr, TContext protected virtual Expr RewriteLeafPrimFunctionWrapper(PrimFunctionWrapper expr, TContext context) => RewriteLeafBaseFunction(expr, context); /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafTensorConst(TensorConst expr, TContext context) => RewriteLeafConst(expr, context); - - /// - /// Rewrite leaf . + /// Rewrite leaf . /// - protected virtual Expr RewriteLeafShapeConst(ShapeConst expr, TContext context) => RewriteLeafConst(expr, context); + protected virtual Expr RewriteLeafShape(IR.Shape expr, TContext context) => DefaultRewriteLeaf(expr, context); /// - /// Rewrite leaf . + /// Rewrite leaf . /// - protected virtual Expr RewriteLeafDimensionConst(DimensionConst expr, TContext context) => RewriteLeafConst(expr, context); + protected virtual Expr RewriteLeafTensorConst(TensorConst expr, TContext context) => RewriteLeafConst(expr, context); /// /// Rewrite leaf . @@ -582,28 +571,20 @@ public partial class ExprRewriter protected sealed override Expr RewriteLeafPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => RewriteLeafPrimFunctionWrapper(expr); /// - /// Rewrite leaf . + /// Rewrite leaf . /// - protected virtual Expr RewriteLeafTensorConst(TensorConst expr) => RewriteLeafConst(expr); + protected virtual Expr RewriteLeafShape(IR.Shape expr) => DefaultRewriteLeaf(expr); /// - protected sealed override Expr RewriteLeafTensorConst(TensorConst expr, Unit context) => RewriteLeafTensorConst(expr); + protected sealed override Expr RewriteLeafShape(IR.Shape expr, Unit context) => RewriteLeafShape(expr); /// - /// Rewrite leaf . - /// - protected virtual Expr RewriteLeafShapeConst(ShapeConst expr) => RewriteLeafConst(expr); - - /// - protected sealed override Expr RewriteLeafShapeConst(ShapeConst expr, Unit context) => RewriteLeafShapeConst(expr); - - /// - /// Rewrite leaf . + /// Rewrite leaf . /// - protected virtual Expr RewriteLeafDimensionConst(DimensionConst expr) => RewriteLeafConst(expr); + protected virtual Expr RewriteLeafTensorConst(TensorConst expr) => RewriteLeafConst(expr); /// - protected sealed override Expr RewriteLeafDimensionConst(DimensionConst expr, Unit context) => RewriteLeafDimensionConst(expr); + protected sealed override Expr RewriteLeafTensorConst(TensorConst expr, Unit context) => RewriteLeafTensorConst(expr); /// /// Rewrite leaf . diff --git a/src/Nncase.Core/IR/ExprVisitor.cs b/src/Nncase.Core/IR/ExprVisitor.cs index c1619d76fa..4ea1d8f821 100644 --- a/src/Nncase.Core/IR/ExprVisitor.cs +++ b/src/Nncase.Core/IR/ExprVisitor.cs @@ -20,14 +20,17 @@ namespace Nncase.IR; public abstract partial class ExprVisitor : ExprFunctor { private readonly bool _visitOtherFunctions; + private readonly bool _visitAttributes; /// /// Initializes a new instance of the class. /// /// Vist other functions. - public ExprVisitor(bool visitOtherFunctions = false) + /// Visit attributes. + public ExprVisitor(bool visitOtherFunctions = false, bool visitAttributes = false) { _visitOtherFunctions = visitOtherFunctions; + _visitAttributes = visitAttributes; } /// @@ -79,6 +82,17 @@ public override TTypeResult VisitType(InvalidType type, TContext context) return MarkVisited(type, VisitTypeLeaf(type, context)); } + /// + public override TTypeResult VisitType(NoneType type, TContext context) + { + if (HasVisited(type, out var result)) + { + return result; + } + + return MarkVisited(type, VisitTypeLeaf(type, context)); + } + /// public override TTypeResult VisitType(TensorType type, TContext context) { @@ -87,6 +101,7 @@ public override TTypeResult VisitType(TensorType type, TContext context) return result; } + Visit(type.Shape, context); return MarkVisited(type, VisitTypeLeaf(type, context)); } @@ -106,6 +121,18 @@ public override TTypeResult VisitType(TupleType type, TContext context) return MarkVisited(type, VisitTypeLeaf(type, context)); } + /// + public override TTypeResult VisitType(DistributedType type, TContext context) + { + if (HasVisited(type, out var result)) + { + return result; + } + + VisitType(type.TensorType, context); + return MarkVisited(type, VisitTypeLeaf(type, context)); + } + /// /// Visit any type leaf. /// @@ -116,6 +143,11 @@ public override TTypeResult VisitType(TupleType type, TContext context) /// public virtual TTypeResult VisitTypeLeaf(InvalidType type, TContext context) => DefaultVisitTypeLeaf(type, context); + /// + /// Visit none type leaf. + /// + public virtual TTypeResult VisitTypeLeaf(NoneType type, TContext context) => DefaultVisitTypeLeaf(type, context); + /// /// Visit tensor type leaf. /// @@ -131,6 +163,11 @@ public override TTypeResult VisitType(TupleType type, TContext context) /// public virtual TTypeResult VisitTypeLeaf(CallableType type, TContext context) => DefaultVisitTypeLeaf(type, context); + /// + /// Visit distributed type leaf. + /// + public virtual TTypeResult VisitTypeLeaf(DistributedType type, TContext context) => DefaultVisitTypeLeaf(type, context); + /// /// Default visit leaf routine. /// @@ -191,6 +228,12 @@ protected bool CanVisitFunctionBody(BaseFunction baseFunction) return ReferenceEquals(baseFunction, VisitRoot); } + protected bool CanVisitAttributes(Expr expr) + { + // Avoid infinite loop + return _visitAttributes && expr is not Shape; + } + /// /// Default leaf visit routine. /// @@ -217,6 +260,17 @@ protected virtual void VisitOperands(Expr expr, TContext context) Visit(operand, context); } } + + protected virtual void VisitAttributes(Expr expr, TContext context) + { + if (_visitAttributes) + { + if (expr.RawCheckedType != null) + { + VisitType(expr.RawCheckedType, context); + } + } + } } /// @@ -230,8 +284,9 @@ public abstract partial class ExprVisitor : ExprVisito /// Initializes a new instance of the class. /// /// Vist other functions. - public ExprVisitor(bool visitOtherFunctions = false) - : base(visitOtherFunctions) + /// Visit attributes. + public ExprVisitor(bool visitOtherFunctions = false, bool visitAttributes = false) + : base(visitOtherFunctions, visitAttributes) { } diff --git a/src/Nncase.Core/IR/ExprVisitor.g.cs b/src/Nncase.Core/IR/ExprVisitor.g.cs index 32d702eb7d..f1ede21f58 100644 --- a/src/Nncase.Core/IR/ExprVisitor.g.cs +++ b/src/Nncase.Core/IR/ExprVisitor.g.cs @@ -17,6 +17,11 @@ public partial class ExprVisitor protected internal override TExprResult VisitCall(Call expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafCall(expr, context); } @@ -28,6 +33,11 @@ protected internal override TExprResult VisitFunction(Function expr, TContext co VisitOperands(expr, context); } + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafFunction(expr, context); } @@ -39,6 +49,11 @@ protected internal override TExprResult VisitFusion(Fusion expr, TContext contex VisitOperands(expr, context); } + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafFusion(expr, context); } @@ -46,6 +61,11 @@ protected internal override TExprResult VisitFusion(Fusion expr, TContext contex protected internal override TExprResult VisitIf(If expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafIf(expr, context); } @@ -53,6 +73,11 @@ protected internal override TExprResult VisitIf(If expr, TContext context) protected internal override TExprResult VisitMarker(Marker expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafMarker(expr, context); } @@ -60,6 +85,11 @@ protected internal override TExprResult VisitMarker(Marker expr, TContext contex protected internal override TExprResult VisitNone(None expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafNone(expr, context); } @@ -67,6 +97,11 @@ protected internal override TExprResult VisitNone(None expr, TContext context) protected internal override TExprResult VisitOp(Op expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafOp(expr, context); } @@ -78,34 +113,47 @@ protected internal override TExprResult VisitPrimFunctionWrapper(PrimFunctionWra VisitOperands(expr, context); } + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafPrimFunctionWrapper(expr, context); } /// - protected internal override TExprResult VisitTensorConst(TensorConst expr, TContext context) + protected internal override TExprResult VisitShape(IR.Shape expr, TContext context) { VisitOperands(expr, context); - return VisitLeafTensorConst(expr, context); - } + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } - /// - protected internal override TExprResult VisitShapeConst(ShapeConst expr, TContext context) - { - VisitOperands(expr, context); - return VisitLeafShapeConst(expr, context); + return VisitLeafShape(expr, context); } /// - protected internal override TExprResult VisitDimensionConst(DimensionConst expr, TContext context) + protected internal override TExprResult VisitTensorConst(TensorConst expr, TContext context) { VisitOperands(expr, context); - return VisitLeafDimensionConst(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + + return VisitLeafTensorConst(expr, context); } /// protected internal override TExprResult VisitTuple(IR.Tuple expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafTuple(expr, context); } @@ -113,6 +161,11 @@ protected internal override TExprResult VisitTuple(IR.Tuple expr, TContext conte protected internal override TExprResult VisitTupleConst(TupleConst expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafTupleConst(expr, context); } @@ -120,6 +173,11 @@ protected internal override TExprResult VisitTupleConst(TupleConst expr, TContex protected internal override TExprResult VisitMemSpan(TIR.MemSpan expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafMemSpan(expr, context); } @@ -127,6 +185,11 @@ protected internal override TExprResult VisitMemSpan(TIR.MemSpan expr, TContext protected internal override TExprResult VisitVar(Var expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafVar(expr, context); } @@ -134,6 +197,11 @@ protected internal override TExprResult VisitVar(Var expr, TContext context) protected internal override TExprResult VisitBlock(TIR.Block expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafBlock(expr, context); } @@ -141,6 +209,11 @@ protected internal override TExprResult VisitBlock(TIR.Block expr, TContext cont protected internal override TExprResult VisitBuffer(TIR.Buffer expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafBuffer(expr, context); } @@ -148,6 +221,11 @@ protected internal override TExprResult VisitBuffer(TIR.Buffer expr, TContext co protected internal override TExprResult VisitBufferRegion(TIR.BufferRegion expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafBufferRegion(expr, context); } @@ -155,6 +233,11 @@ protected internal override TExprResult VisitBufferRegion(TIR.BufferRegion expr, protected internal override TExprResult VisitFor(TIR.For expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafFor(expr, context); } @@ -162,6 +245,11 @@ protected internal override TExprResult VisitFor(TIR.For expr, TContext context) protected internal override TExprResult VisitIfThenElse(TIR.IfThenElse expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafIfThenElse(expr, context); } @@ -169,6 +257,11 @@ protected internal override TExprResult VisitIfThenElse(TIR.IfThenElse expr, TCo protected internal override TExprResult VisitLet(TIR.Let expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafLet(expr, context); } @@ -180,6 +273,11 @@ protected internal override TExprResult VisitPrimFunction(TIR.PrimFunction expr, VisitOperands(expr, context); } + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafPrimFunction(expr, context); } @@ -187,6 +285,11 @@ protected internal override TExprResult VisitPrimFunction(TIR.PrimFunction expr, protected internal override TExprResult VisitSequential(TIR.Sequential expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafSequential(expr, context); } @@ -194,6 +297,11 @@ protected internal override TExprResult VisitSequential(TIR.Sequential expr, TCo protected internal override TExprResult VisitRange(TIR.Range expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafRange(expr, context); } @@ -201,6 +309,11 @@ protected internal override TExprResult VisitRange(TIR.Range expr, TContext cont protected internal override TExprResult VisitIterVar(TIR.IterVar expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafIterVar(expr, context); } @@ -208,6 +321,11 @@ protected internal override TExprResult VisitIterVar(TIR.IterVar expr, TContext protected internal override TExprResult VisitAffineDim(Affine.AffineDim expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafAffineDim(expr, context); } @@ -215,6 +333,11 @@ protected internal override TExprResult VisitAffineDim(Affine.AffineDim expr, TC protected internal override TExprResult VisitAffineExtent(Affine.AffineExtent expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafAffineExtent(expr, context); } @@ -222,6 +345,11 @@ protected internal override TExprResult VisitAffineExtent(Affine.AffineExtent ex protected internal override TExprResult VisitAffineSymbol(Affine.AffineSymbol expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafAffineSymbol(expr, context); } @@ -229,6 +357,11 @@ protected internal override TExprResult VisitAffineSymbol(Affine.AffineSymbol ex protected internal override TExprResult VisitAffineConstant(Affine.AffineConstant expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafAffineConstant(expr, context); } @@ -236,6 +369,11 @@ protected internal override TExprResult VisitAffineConstant(Affine.AffineConstan protected internal override TExprResult VisitAffineAddBinary(Affine.AffineAddBinary expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafAffineAddBinary(expr, context); } @@ -243,6 +381,11 @@ protected internal override TExprResult VisitAffineAddBinary(Affine.AffineAddBin protected internal override TExprResult VisitAffineMulBinary(Affine.AffineMulBinary expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafAffineMulBinary(expr, context); } @@ -250,6 +393,11 @@ protected internal override TExprResult VisitAffineMulBinary(Affine.AffineMulBin protected internal override TExprResult VisitAffineDivBinary(Affine.AffineDivBinary expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafAffineDivBinary(expr, context); } @@ -257,6 +405,11 @@ protected internal override TExprResult VisitAffineDivBinary(Affine.AffineDivBin protected internal override TExprResult VisitAffineDomain(Affine.AffineDomain expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafAffineDomain(expr, context); } @@ -264,6 +417,11 @@ protected internal override TExprResult VisitAffineDomain(Affine.AffineDomain ex protected internal override TExprResult VisitAffineRange(Affine.AffineRange expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafAffineRange(expr, context); } @@ -271,6 +429,11 @@ protected internal override TExprResult VisitAffineRange(Affine.AffineRange expr protected internal override TExprResult VisitAffineMap(Affine.AffineMap expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafAffineMap(expr, context); } @@ -278,6 +441,11 @@ protected internal override TExprResult VisitAffineMap(Affine.AffineMap expr, TC protected internal override TExprResult VisitAffineRelation(Affine.AffineRelation expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafAffineRelation(expr, context); } @@ -285,6 +453,11 @@ protected internal override TExprResult VisitAffineRelation(Affine.AffineRelatio protected internal override TExprResult VisitGrid(Affine.Grid expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafGrid(expr, context); } @@ -292,6 +465,11 @@ protected internal override TExprResult VisitGrid(Affine.Grid expr, TContext con protected internal override TExprResult VisitLoad(Affine.Load expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafLoad(expr, context); } @@ -299,6 +477,11 @@ protected internal override TExprResult VisitLoad(Affine.Load expr, TContext con protected internal override TExprResult VisitFor(Affine.For expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafFor(expr, context); } @@ -306,6 +489,11 @@ protected internal override TExprResult VisitFor(Affine.For expr, TContext conte protected internal override TExprResult VisitBufferOf(Buffers.BufferOf expr, TContext context) { VisitOperands(expr, context); + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeafBufferOf(expr, context); } @@ -360,19 +548,14 @@ protected internal override TExprResult VisitBufferOf(Buffers.BufferOf expr, TCo protected virtual TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, TContext context) => VisitLeafBaseFunction(expr, context); /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafTensorConst(TensorConst expr, TContext context) => VisitLeafConst(expr, context); - - /// - /// Visit leaf . + /// Visit leaf . /// - protected virtual TExprResult VisitLeafShapeConst(ShapeConst expr, TContext context) => VisitLeafConst(expr, context); + protected virtual TExprResult VisitLeafShape(IR.Shape expr, TContext context) => DefaultVisitLeaf(expr, context); /// - /// Visit leaf . + /// Visit leaf . /// - protected virtual TExprResult VisitLeafDimensionConst(DimensionConst expr, TContext context) => VisitLeafConst(expr, context); + protected virtual TExprResult VisitLeafTensorConst(TensorConst expr, TContext context) => VisitLeafConst(expr, context); /// /// Visit leaf . @@ -590,26 +773,19 @@ public partial class ExprVisitor /// internal protected sealed override TExprResult VisitPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitPrimFunctionWrapper(expr); /// - /// Visit . - /// - internal protected virtual TExprResult VisitTensorConst(TensorConst expr) => base.VisitTensorConst(expr, default); - - /// - internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr); - /// - /// Visit . + /// Visit . /// - internal protected virtual TExprResult VisitShapeConst(ShapeConst expr) => base.VisitShapeConst(expr, default); + internal protected virtual TExprResult VisitShape(IR.Shape expr) => base.VisitShape(expr, default); /// - internal protected sealed override TExprResult VisitShapeConst(ShapeConst expr, Unit context) => VisitShapeConst(expr); + internal protected sealed override TExprResult VisitShape(IR.Shape expr, Unit context) => VisitShape(expr); /// - /// Visit . + /// Visit . /// - internal protected virtual TExprResult VisitDimensionConst(DimensionConst expr) => base.VisitDimensionConst(expr, default); + internal protected virtual TExprResult VisitTensorConst(TensorConst expr) => base.VisitTensorConst(expr, default); /// - internal protected sealed override TExprResult VisitDimensionConst(DimensionConst expr, Unit context) => VisitDimensionConst(expr); + internal protected sealed override TExprResult VisitTensorConst(TensorConst expr, Unit context) => VisitTensorConst(expr); /// /// Visit . /// @@ -894,28 +1070,20 @@ public partial class ExprVisitor protected sealed override TExprResult VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr, Unit context) => VisitLeafPrimFunctionWrapper(expr); /// - /// Visit leaf . + /// Visit leaf . /// - protected virtual TExprResult VisitLeafTensorConst(TensorConst expr) => base.VisitLeafTensorConst(expr, default); + protected virtual TExprResult VisitLeafShape(IR.Shape expr) => base.VisitLeafShape(expr, default); /// - protected sealed override TExprResult VisitLeafTensorConst(TensorConst expr, Unit context) => VisitLeafTensorConst(expr); + protected sealed override TExprResult VisitLeafShape(IR.Shape expr, Unit context) => VisitLeafShape(expr); /// - /// Visit leaf . - /// - protected virtual TExprResult VisitLeafShapeConst(ShapeConst expr) => base.VisitLeafShapeConst(expr, default); - - /// - protected sealed override TExprResult VisitLeafShapeConst(ShapeConst expr, Unit context) => VisitLeafShapeConst(expr); - - /// - /// Visit leaf . + /// Visit leaf . /// - protected virtual TExprResult VisitLeafDimensionConst(DimensionConst expr) => base.VisitLeafDimensionConst(expr, default); + protected virtual TExprResult VisitLeafTensorConst(TensorConst expr) => base.VisitLeafTensorConst(expr, default); /// - protected sealed override TExprResult VisitLeafDimensionConst(DimensionConst expr, Unit context) => VisitLeafDimensionConst(expr); + protected sealed override TExprResult VisitLeafTensorConst(TensorConst expr, Unit context) => VisitLeafTensorConst(expr); /// /// Visit leaf . diff --git a/src/Nncase.Core/IR/ExprVisitor.g.tt b/src/Nncase.Core/IR/ExprVisitor.g.tt index b8d267b993..4cc7d1bbb4 100644 --- a/src/Nncase.Core/IR/ExprVisitor.g.tt +++ b/src/Nncase.Core/IR/ExprVisitor.g.tt @@ -49,6 +49,11 @@ foreach (var ir in irs.Where(x => x.IsDerived)) <# } #> + if (CanVisitAttributes(expr)) + { + VisitAttributes(expr, context); + } + return VisitLeaf<#=ir.Name#>(expr, context); } diff --git a/src/Nncase.Core/IR/ExprWalker.cs b/src/Nncase.Core/IR/ExprWalker.cs index 07d3427df4..ac90a419a5 100644 --- a/src/Nncase.Core/IR/ExprWalker.cs +++ b/src/Nncase.Core/IR/ExprWalker.cs @@ -16,8 +16,9 @@ public abstract class ExprWalker : ExprVisitor /// Initializes a new instance of the class. /// /// Vist other functions. - public ExprWalker(bool visitOtherFunctions = false) - : base(visitOtherFunctions) + /// Visit attributes. + public ExprWalker(bool visitOtherFunctions = false, bool visitAttributes = false) + : base(visitOtherFunctions, visitAttributes) { } @@ -30,8 +31,9 @@ public abstract class ExprWalker : ExprVisitor /// Initializes a new instance of the class. /// /// Vist other functions. - public ExprWalker(bool visitOtherFunctions = false) - : base(visitOtherFunctions) + /// Visit attributes. + public ExprWalker(bool visitOtherFunctions = false, bool visitAttributes = false) + : base(visitOtherFunctions, visitAttributes) { } diff --git a/src/Nncase.Core/IR/IRList.csv b/src/Nncase.Core/IR/IRList.csv index 253041d488..7ea4982e63 100644 --- a/src/Nncase.Core/IR/IRList.csv +++ b/src/Nncase.Core/IR/IRList.csv @@ -8,9 +8,8 @@ Marker,true,false,Default,,Target;Attribute None,true,false,Default,, Op,true,false,Default,, PrimFunctionWrapper,true,true,BaseFunction,,Target +Shape,true,false,Default,IR.,@Dimensions TensorConst,true,false,Const,, -ShapeConst,true,false,Const,, -DimensionConst,true,false,Const,, Tuple,true,false,Default,IR.,@Fields TupleConst,true,false,Const,, MemSpan,true,false,Default,TIR.,Start;Size; diff --git a/src/Nncase.Core/IR/NN/OneHot.cs b/src/Nncase.Core/IR/NN/OneHot.cs index a6ecf22c72..67e2681106 100644 --- a/src/Nncase.Core/IR/NN/OneHot.cs +++ b/src/Nncase.Core/IR/NN/OneHot.cs @@ -25,7 +25,7 @@ public sealed partial class OneHot : Op /// /// Gets depth. /// - public static readonly ParameterInfo Depth = new(typeof(OneHot), 1, "depth"); + public static new readonly ParameterInfo Depth = new(typeof(OneHot), 1, "depth"); /// /// Gets values. diff --git a/src/Nncase.Core/IR/NN/SpaceToBatch.cs b/src/Nncase.Core/IR/NN/SpaceToBatch.cs index 8fe5fd82be..17aa97544e 100644 --- a/src/Nncase.Core/IR/NN/SpaceToBatch.cs +++ b/src/Nncase.Core/IR/NN/SpaceToBatch.cs @@ -31,5 +31,5 @@ public sealed partial class SpaceToBatch : Op /// /// Gets paddings. /// - public static readonly ParameterInfo Paddings = new(typeof(SpaceToBatch), 2, "paddings", HasShape(new[] { Dimension.Any, 2 }) & IsIntegral()); + public static readonly ParameterInfo Paddings = new(typeof(SpaceToBatch), 2, "paddings", HasShape(new[] { Dimension.Unknown, 2 }) & IsIntegral()); } diff --git a/src/Nncase.Core/IR/Shape.cs b/src/Nncase.Core/IR/Shape.cs index 9025a17280..ff268e2019 100644 --- a/src/Nncase.Core/IR/Shape.cs +++ b/src/Nncase.Core/IR/Shape.cs @@ -8,482 +8,514 @@ using System.Linq; using System.Text; using System.Threading.Tasks; +using CommunityToolkit.HighPerformance.Helpers; using NetFabric.Hyperlinq; using Nncase.IR.Tensors; +using Nncase.Utilities; -namespace Nncase.IR +namespace Nncase.IR; + +/// +/// Shape kind. +/// +public enum ShapeKind { /// - /// Shape kind. + /// Invalid shape. /// - public enum ShapeKind - { - /// - /// Invalid shape. - /// - Invalid, + Invalid, - /// - /// Unranked shape. - /// - Unranked, + /// + /// Unranked shape. + /// + Unranked, - /// - /// Shape contains unknown dimensions. - /// - HasUnknownDimension, + /// + /// Shape contains unknown dimensions. + /// + HasUnknownDimension, - /// - /// Fixed shape. - /// - Fixed, - } + /// + /// Fixed shape. + /// + Fixed, +} - public record struct FixedAndDynamicDimension(long Fixed, Dimension? Dynamic) - { - public static implicit operator FixedAndDynamicDimension((long Fixed, Dimension? Dynamic) value) => new FixedAndDynamicDimension(value.Fixed, value.Dynamic); +public record struct FixedAndDynamicDimension(long Fixed, Dimension? Dynamic) +{ + public static implicit operator FixedAndDynamicDimension((long Fixed, Dimension? Dynamic) value) => new FixedAndDynamicDimension(value.Fixed, value.Dynamic); - public static FixedAndDynamicDimension operator *(FixedAndDynamicDimension a, FixedAndDynamicDimension b) + public static FixedAndDynamicDimension operator *(FixedAndDynamicDimension a, FixedAndDynamicDimension b) + { + var dyn = (a.Dynamic, b.Dynamic) switch { - var dyn = (a.Dynamic, b.Dynamic) switch - { - (null, null) => null, - (null, Dimension x) => x, - (Dimension x, null) => x, - (Dimension x, Dimension y) => x * y, - }; - return new FixedAndDynamicDimension(a.Fixed * b.Fixed, dyn); - } + (null, null) => (Dimension?)null, + (null, Dimension x) => x, + (Dimension x, null) => x, + (Dimension x, Dimension y) => x * y, + }; + return new FixedAndDynamicDimension(a.Fixed * b.Fixed, dyn); + } - public static FixedAndDynamicDimension operator /(FixedAndDynamicDimension a, long b) + public static FixedAndDynamicDimension operator /(FixedAndDynamicDimension a, long b) + { + if (a.Fixed % b == 0 || a.Dynamic is null) { - if (a.Fixed % b == 0 || a.Dynamic is null) - { - return new FixedAndDynamicDimension(a.Fixed / b, a.Dynamic); - } - - return new FixedAndDynamicDimension(1, a.Fixed * a.Dynamic.Value / b); + return new FixedAndDynamicDimension(a.Fixed / b, a.Dynamic); } - public static FixedAndDynamicDimension operator /(FixedAndDynamicDimension a, FixedAndDynamicDimension b) - { - if (a.Fixed % b.Fixed == 0) - { - return (a.Dynamic, b.Dynamic) switch - { - (null, null) => new FixedAndDynamicDimension(a.Fixed / b.Fixed, null), - (null, Dimension y) => new FixedAndDynamicDimension(1, a.Fixed / b.Fixed / y), - (Dimension x, null) => new FixedAndDynamicDimension(a.Fixed / b.Fixed, x), - (Dimension x, Dimension y) when x == y => new FixedAndDynamicDimension(a.Fixed / b.Fixed, null), - (Dimension x, Dimension y) => new FixedAndDynamicDimension(1, a.Fixed / b.Fixed * x / y), - }; - } + return new FixedAndDynamicDimension(1, a.Fixed * a.Dynamic.Value / b); + } + public static FixedAndDynamicDimension operator /(FixedAndDynamicDimension a, FixedAndDynamicDimension b) + { + if (a.Fixed % b.Fixed == 0) + { return (a.Dynamic, b.Dynamic) switch { (null, null) => new FixedAndDynamicDimension(a.Fixed / b.Fixed, null), - (null, Dimension y) => new FixedAndDynamicDimension(1, a.Fixed / (b.Fixed * y)), - (Dimension x, null) => new FixedAndDynamicDimension(1, a.Fixed * x / b.Fixed), + (null, Dimension y) => new FixedAndDynamicDimension(1, a.Fixed / b.Fixed / y), + (Dimension x, null) => new FixedAndDynamicDimension(a.Fixed / b.Fixed, x), (Dimension x, Dimension y) when x == y => new FixedAndDynamicDimension(a.Fixed / b.Fixed, null), - (Dimension x, Dimension y) => new FixedAndDynamicDimension(1, a.Fixed * x / (b.Fixed * y)), + (Dimension x, Dimension y) => new FixedAndDynamicDimension(1, a.Fixed / b.Fixed * x / y), }; } - public static FixedAndDynamicDimension Abs(FixedAndDynamicDimension value) => - new(System.Math.Abs(value.Fixed), value.Dynamic is null ? null : Dimension.Abs(value.Dynamic.Value)); - - public Dimension ToDimension() + return (a.Dynamic, b.Dynamic) switch { - return Dynamic is null ? new Dimension(Fixed) : new Dimension(Fixed) * Dynamic.Value; - } + (null, null) => new FixedAndDynamicDimension(a.Fixed / b.Fixed, null), + (null, Dimension y) => new FixedAndDynamicDimension(1, a.Fixed / (b.Fixed * y)), + (Dimension x, null) => new FixedAndDynamicDimension(1, a.Fixed * x / b.Fixed), + (Dimension x, Dimension y) when x == y => new FixedAndDynamicDimension(a.Fixed / b.Fixed, null), + (Dimension x, Dimension y) => new FixedAndDynamicDimension(1, a.Fixed * x / (b.Fixed * y)), + }; + } - public Expr ToExpr() - { - return Dynamic is null ? Fixed : Fixed * Dynamic.Value; - } + public static FixedAndDynamicDimension Abs(FixedAndDynamicDimension value) => + new(System.Math.Abs(value.Fixed), value.Dynamic is null ? (Dimension?)null : Dimension.Abs(value.Dynamic.Value)); + + public Dimension ToDimension() + { + return Dynamic is null ? new Dimension(Fixed) : new Dimension(Fixed) * Dynamic.Value; + } + + public Expr ToExpr() + { + return Dynamic is null ? Fixed : Fixed * Dynamic.Value.ToExpr(); } +} +/// +/// Tensor shape. +/// +public sealed class Shape : Expr, IEquatable, IReadOnlyList +{ /// - /// Tensor shape. + /// Initializes a new instance of the class. /// - public sealed class Shape : IStructuralEquatable, IReadOnlyList, IEquatable, IEnumerable + /// Dimensions. + public Shape(ReadOnlySpan dimensions) + : this(dimensions.ToArray()) { - private readonly ImmutableArray _dimensions; + } - private readonly int _hashcode; + /// + /// Initializes a new instance of the class. + /// init from the dimensions + /// Initializes a new instance of the class. + /// + /// Dimensions. + public Shape(ReadOnlySpan dimensions) + : this(dimensions.AsValueEnumerable().Select(x => (Expr)x).ToArray()) + { + } - /// - /// Initializes a new instance of the class. - /// - /// Dimensions. - public Shape(ReadOnlySpan dimensions) - { - Kind = KindOf(dimensions); - _dimensions = ImmutableArray.Create(dimensions.ToArray()); - _hashcode = StructuralComparisons.StructuralEqualityComparer.GetHashCode(_dimensions); - } + /// + /// Initializes a new instance of the class. + /// + /// Dimensions. + public Shape(ReadOnlySpan dimensions) + : this(dimensions.AsValueEnumerable().Select(i => (Expr)i).ToArray()) + { + } - /// - /// Initializes a new instance of the class. - /// init from the dimensions - /// Initializes a new instance of the class. - /// - /// Dimensions. - public Shape(ReadOnlySpan dimensions) - : this(dimensions.AsValueEnumerable().Select(x => new Dimension(x)).ToArray()) - { - } + /// + /// Initializes a new instance of the class. + /// + /// Dimensions. + public Shape(ReadOnlySpan dimensions) + : this(dimensions.AsValueEnumerable().Select(x => x.ToExpr()).ToArray()) + { + } - /// - /// Initializes a new instance of the class. - /// - /// Dimensions. - public Shape(ReadOnlySpan dimensions) - : this(dimensions.AsValueEnumerable().Select(i => (int)i).ToArray()) - { - } + /// + /// Initializes a new instance of the class. + /// + /// Dimensions. + public Shape(params Expr[] dimensions) + : base(dimensions.ToArray()) + { + RefreshKind(); + } - /// - /// Initializes a new instance of the class. - /// - /// Dimensions. - public Shape(params Dimension[] dimensions) - : this(dimensions.AsSpan()) - { - } + /// + /// Initializes a new instance of the class. + /// + /// Dimensions. + public Shape(params int[] dimensions) + : this(dimensions.AsSpan()) + { + } - /// - /// Initializes a new instance of the class. - /// - /// Dimensions. - public Shape(params int[] dimensions) - : this(dimensions.AsSpan()) - { - } + /// + /// Initializes a new instance of the class. + /// + /// Dimensions. + public Shape(params long[] dimensions) + : this(dimensions.AsSpan()) + { + } - /// - /// Initializes a new instance of the class. - /// - /// Dimensions. - public Shape(IEnumerable dimensions) - : this(dimensions.ToArray()) - { - } + /// + /// Initializes a new instance of the class. + /// + /// Dimensions. + public Shape(params Dimension[] dimensions) + : this(dimensions.AsSpan()) + { + } - /// - /// Initializes a new instance of the class. - /// - /// Dimensions. - public Shape(IEnumerable dimensions) - : this(dimensions.ToArray()) - { - } + /// + /// Initializes a new instance of the class. + /// + /// Dimensions. + public Shape(IEnumerable dimensions) + : this(dimensions.ToArray()) + { + } - /// - /// Initializes a new instance of the class. - /// - /// Dimensions. - public Shape(IEnumerable dimensions) - : this(dimensions.Select(i => (Dimension)i).ToArray()) - { - } + /// + /// Initializes a new instance of the class. + /// + /// Dimensions. + public Shape(IEnumerable dimensions) + : this(dimensions.Select(x => (Expr)x).ToArray()) + { + } - /// - /// Initializes a new instance of the class. - /// - /// Dimensions. - public Shape(ReadOnlySpan dimensions) - : this(dimensions.AsValueEnumerable().Select(i => (Dimension)i).ToArray()) - { - } + /// + /// Initializes a new instance of the class. + /// + /// Dimensions. + public Shape(IEnumerable dimensions) + : this(dimensions.Select(x => (Expr)x).ToArray()) + { + } - private Shape(ShapeKind kind, IEnumerable dimensions) - { - Kind = kind; - _dimensions = dimensions.ToImmutableArray(); - _hashcode = StructuralComparisons.StructuralEqualityComparer.GetHashCode(_dimensions); - } + /// + /// Initializes a new instance of the class. + /// + /// Dimensions. + public Shape(IEnumerable dimensions) + : this(dimensions.Select(x => x.ToExpr()).ToArray()) + { + } - /// - /// Gets an invalid shape. - /// - public static Shape Invalid { get; } = new Shape(ShapeKind.Invalid, new List()); - - /// - /// Gets an unranked shape. - /// - public static Shape Unranked { get; } = new Shape(ShapeKind.Unranked, new List()); - - /// - /// Gets a scalar shape. - /// - public static Shape Scalar { get; } = new Shape(ShapeKind.Fixed, new List()); - - /// - /// Gets kind. - /// - public ShapeKind Kind { get; private set; } - - /// - /// Gets a value indicating whether is readonly. - /// - public bool IsReadOnly => true; - - /// - /// Gets a value indicating whether fixed. - /// - public bool IsFixed => Kind == ShapeKind.Fixed; - - /// - /// Gets a value indicating whether invalid. - /// - public bool IsInvalid => Kind == ShapeKind.Invalid; - - /// - /// Gets a value indicating whether unranked. - /// - public bool IsUnranked => Kind == ShapeKind.Unranked; - - /// - /// Gets a value indicating whether has unknown dimension. - /// - public bool HasUnknownDimension => Kind == ShapeKind.HasUnknownDimension; - - /// - /// Gets a value indicating whether ranked. - /// - public bool IsRanked => IsFixed || HasUnknownDimension; - - /// - /// Gets a value indicating whether scalar. - /// - public bool IsScalar => IsFixed && _dimensions.Length == 0; - - /// - /// Gets rank. - /// - public int Rank => _dimensions.Length; - - /// - /// Gets get Total Elements. - /// - public long Size => Enumerable.Range(0, Rank).Aggregate(1L, (size, i) => size * _dimensions[i].FixedValue); - - /// - public int Count => ((IReadOnlyCollection)_dimensions).Count; - - /// - public Dimension this[int index] => - index >= 0 - ? ((IReadOnlyList)_dimensions)[index] - : ((IReadOnlyList)_dimensions)[Rank + index]; - - public static implicit operator ReadOnlySpan(Shape shape) => shape._dimensions.Select(x => x.FixedValue).ToArray(); - - public static implicit operator Shape(Dimension[] dimensions) => new Shape(dimensions); - - public static implicit operator Shape(int[] dimensions) => new Shape(dimensions); - - public static implicit operator Shape(long[] dimensions) => new Shape(dimensions); - - public static bool operator ==(Shape lhs, Shape rhs) - { - return lhs.Equals(rhs); - } + private Shape(ShapeKind kind) + : base(Array.Empty()) + { + Kind = kind; + } - public static bool operator !=(Shape lhs, Shape rhs) - { - return !(lhs == rhs); - } + /// + /// Gets an invalid shape. + /// + public static Shape Invalid { get; } = new Shape(ShapeKind.Invalid); - /// - /// Gets a shape with rank unknwon dimension. - /// - public static Shape Unknown(int rank) + private static readonly Shape _unranked = new Shape(ShapeKind.Unranked); + + /// + /// Gets an unranked shape. + /// + public static Shape Unranked + { + get { - return new Shape(ShapeKind.HasUnknownDimension, Enumerable.Range(0, rank).Select(x => Dimension.Unknown())); + Console.WriteLine(); + return _unranked; } + } - /// - /// Gets a shape with rank unknwon dimension. - /// - public static Shape FromExpr(Expr value) - { - if (value is TensorConst tc) - { - return new Shape(tc.Value.ToArray()); - } - else if (value is Call { Target: Concat } concat) - { - if (concat.Arguments[Concat.Input.Index] is Tuple tuple) - { - return new Shape(tuple.Fields); - } - } + /// + /// Gets a scalar shape. + /// + public static Shape Scalar { get; } = new Shape(ShapeKind.Fixed); - var shape = value.CheckedShape; - if (shape.Rank != 1 || !shape.IsFixed) - { - // throw new ArgumentException($"Invalid shape expr: {value}", nameof(value)); - return Shape.Unranked; - } + /// + /// Gets dimensions. + /// + public ReadOnlySpan Dimensions => Operands; - var rank = (int)shape[0].FixedValue; - return new Shape(Enumerable.Range(0, rank).Select(x => (Dimension)value[x])); - } + /// + /// Gets kind. + /// + public ShapeKind Kind { get; private set; } - /// - /// Get Prod. - /// - public Dimension Prod() - { - return _dimensions.Aggregate(new Dimension(1), (x, y) => x * y); - } + /// + /// Gets a value indicating whether is readonly. + /// + public bool IsReadOnly => true; - public FixedAndDynamicDimension ProdFixedAndDynamic() - { - var fixedValue = 1L; - var dynamicValue = new Dimension(1); - foreach (var dim in _dimensions) - { - if (dim.IsFixed) - { - fixedValue *= dim.FixedValue; - } - else - { - dynamicValue *= dim; - } - } + /// + /// Gets a value indicating whether fixed. + /// + public bool IsFixed => Kind == ShapeKind.Fixed; - return new(fixedValue, dynamicValue.IsFixed ? null : dynamicValue); - } + /// + /// Gets a value indicating whether invalid. + /// + public bool IsInvalid => Kind == ShapeKind.Invalid; + + /// + /// Gets a value indicating whether unranked. + /// + public bool IsUnranked => Kind == ShapeKind.Unranked; + + /// + /// Gets a value indicating whether has unknown dimension. + /// + public bool HasUnknownDimension => Kind == ShapeKind.HasUnknownDimension; + + /// + /// Gets a value indicating whether ranked. + /// + public bool IsRanked => IsFixed || HasUnknownDimension; + + /// + /// Gets a value indicating whether scalar. + /// + public bool IsScalar => IsFixed && Dimensions.Length == 0; + + /// + /// Gets rank. + /// + public int Rank => IsRanked ? Dimensions.Length : throw new InvalidOperationException("Shape is unranked"); + + /// + /// Gets get Total Elements. + /// + public long Size => Enumerable.Range(0, Rank).Aggregate(1L, (size, i) => size * this[i].FixedValue); - /// - /// return new shape after insert dim. - /// - public Shape InsertAndClone(int index, Dimension dim) + /// + public int Count => Operands.Length; + + /// + /// Gets the dimension. + /// + /// Index, allowing negative value. + /// Dimension. + public Dimension this[int index] => index >= 0 ? Dimensions[index] : Dimensions[Rank + index]; + + public static implicit operator ReadOnlySpan(Shape shape) => shape.Select(x => x.FixedValue).ToArray(); + + public static implicit operator Shape(int[] dimensions) => new Shape(dimensions); + + public static implicit operator Shape(long[] dimensions) => new Shape(dimensions); + + public static implicit operator Shape(Dimension[] dimensions) => new Shape(dimensions); + + public static bool operator ==(Shape? lhs, Shape? rhs) + { + return EqualityComparer.Default.Equals(lhs, rhs); + } + + public static bool operator !=(Shape? lhs, Shape? rhs) + { + return !(lhs == rhs); + } + + /// + /// Gets a shape with rank unknwon dimension. + /// + public static Shape Unknown(int rank) + { + return new Shape(Enumerable.Range(0, rank).Select(x => Dimension.Unknown)); + } + + /// + /// Gets a shape with rank unknwon dimension. + /// + public static Shape FromExpr(Expr value) + { + if (value is TensorConst tc) { - var l = _dimensions.ToList(); - l.Insert(index, dim); - return new Shape(l.ToArray()); + return new Shape(tc.Value.ToArray()); } - - /// - /// return new shape after insert dim. - /// - public Shape InsertAndClone(int index, IEnumerable dims) + else if (value is Call { Target: Concat } concat) { - var l = _dimensions.ToList(); - foreach (var d in dims) + if (concat.Arguments[Concat.Input.Index] is Tuple tuple) { - l.Insert(index++, d); + return new Shape(tuple.Fields); } - - return new Shape(l.ToArray()); } - /// - /// convert to the int list. - /// - public List ToValueList() + var shape = value.CheckedShape; + if (shape.Rank != 1 || !shape.IsFixed) { - return _dimensions.Select(dim => dim.FixedValue).ToList(); + // throw new ArgumentException($"Invalid shape expr: {value}", nameof(value)); + return Shape.Unranked; } - /// - /// convert the int array. - /// - public long[] ToValueArray() + var rank = (int)shape[0].FixedValue; + return new Shape(Enumerable.Range(0, rank).Select(x => value[x])); + } + + public IEnumerator GetEnumerator() + { + for (int i = 0; i < Count; i++) { - return _dimensions.Select(dim => dim.FixedValue).ToArray(); + yield return Dimensions[i]; } + } - /// - public override string ToString() => Kind switch - { - ShapeKind.Invalid => "Invalid", - ShapeKind.Unranked => "Unranked", - _ => $"[{string.Join(',', _dimensions)}]", - }; + IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - /// - public int GetHashCode(IEqualityComparer comparer) - { - return ((IStructuralEquatable)_dimensions).GetHashCode(comparer); - } + /// + /// Get Prod. + /// + public Dimension Prod() + { + return Enumerable.Range(0, Rank).Aggregate((Dimension)1L, (size, i) => size * this[i]); + } - /// - public override int GetHashCode() + public FixedAndDynamicDimension ProdFixedAndDynamic() + { + var fixedValue = 1L; + Dimension? dynamicValue = null; + foreach (var dim in this) { - return _hashcode; + if (dim.IsFixed) + { + fixedValue *= dim.FixedValue; + } + else + { + dynamicValue = dynamicValue is null ? dim : dynamicValue * dim; + } } - /// - public IEnumerator GetEnumerator() - { - return ((IEnumerable)_dimensions).GetEnumerator(); - } + return new(fixedValue, dynamicValue); + } + + public long ProdWithDynamicAsOne(long scale = 1) => + Enumerable.Range(0, Rank).Aggregate(scale, (acc, x) => acc * (this[x].IsFixed ? this[x].FixedValue : 1)); - IEnumerator IEnumerable.GetEnumerator() + /// + /// return new shape after insert dim. + /// + public Shape InsertAndClone(int index, Dimension dim) + { + var l = Dimensions.AsValueEnumerable().ToList(); + l.Insert(index, dim.ToExpr()); + return new Shape(l.ToArray()); + } + + /// + /// return new shape after insert dim. + /// + public Shape InsertAndClone(int index, IEnumerable dims) + { + var l = Dimensions.AsValueEnumerable().ToList(); + foreach (var d in dims) { - return ((IEnumerable)_dimensions).GetEnumerator(); + l.Insert(index++, d.ToExpr()); } - /// - public bool Equals(object? other, IEqualityComparer comparer) + return new Shape(l.ToArray()); + } + + /// + /// convert to the int list. + /// + public List ToValueList() + { + return this.Select(x => x.FixedValue).ToList(); + } + + /// + /// convert the int array. + /// + public long[] ToValueArray() + { + return this.Select(x => x.FixedValue).ToArray(); + } + + /// + public override string ToString() => Kind switch + { + ShapeKind.Invalid => "Invalid", + ShapeKind.Unranked => "Unranked", + _ => $"[{StringUtility.Join(',', Dimensions)}]", + }; + + /// + public bool Equals(Shape? other) + { + if (ReferenceEquals(this, other)) { - return other is Shape shape && ((IStructuralEquatable)_dimensions).Equals(shape._dimensions, comparer); + return true; } - /// - public bool Equals(Shape? other) + return other is not null && Dimensions.SequenceEqual(other.Dimensions); + } + + /// + public override bool Equals(object? other) + { + return other is Shape shape && Equals(shape); + } + + public bool IsAssignableFrom(Shape shape) + { + if (IsUnranked) { - return other is not null && StructuralComparisons.StructuralEqualityComparer.Equals(_dimensions, other._dimensions); + return true; } - /// - public override bool Equals(object? other) + if (shape.IsUnranked || Rank != shape.Rank) { - return other is Shape shape && Equals(shape); + return false; } - public bool IsAssignableFrom(Shape shape) + for (int i = 0; i < Dimensions.Length; i++) { - if (IsUnranked) - { - return true; - } - - if (shape.IsUnranked || Rank != shape.Rank) + if (!this[i].IsAssignableFrom(shape[i])) { return false; } + } - for (int i = 0; i < _dimensions.Length; i++) - { - if (!_dimensions[i].IsAssignableFrom(shape[i])) - { - return false; - } - } + return true; + } - return true; - } + public override TExprResult Accept(ExprFunctor functor, TContext context) => + functor.VisitShape(this, context); - public IR.Tuple ToTuple() - { - if (IsUnranked) - { - throw new InvalidOperationException("Cannot convert unranked shape to tuple"); - } + public Shape With(Expr[]? dimensions = null) => new Shape(dimensions ?? Dimensions); - return new IR.Tuple(_dimensions.Select(x => x.ToExpr()).ToArray()); - } + protected override int GetHashCodeCore() + { + return HashCode.Combine(GetType(), Kind, HashCode.Combine(Operands)); + } - private static ShapeKind KindOf(ReadOnlySpan dimensions) - { - return dimensions.AsValueEnumerable().Any(x => x.IsUnknown) ? ShapeKind.HasUnknownDimension : ShapeKind.Fixed; - } + protected override void OnOperandsReplaced() + { + base.OnOperandsReplaced(); + RefreshKind(); + } + + private void RefreshKind() + { + Kind = this.All(x => x.IsFixed) ? ShapeKind.Fixed : ShapeKind.HasUnknownDimension; } } diff --git a/src/Nncase.Core/IR/ShapeConst.cs b/src/Nncase.Core/IR/ShapeConst.cs deleted file mode 100644 index e1f2d4abcf..0000000000 --- a/src/Nncase.Core/IR/ShapeConst.cs +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; -using System.Threading.Tasks; - -namespace Nncase.IR; - -/// -/// Constant of shape. -/// -public sealed class ShapeConst : Const, IEquatable -{ - public ShapeConst(Shape shape) - : base(new TensorType(DataTypes.Int64, new[] { shape.Rank })) - { - Value = shape; - } - - public Shape Value { get; } - - /// - public override string ToString() - { - return Value.ToString(); - } - - /// - public override TExprResult Accept(ExprFunctor functor, TContext context) - => functor.VisitShapeConst(this, context); - - public ShapeConst With(Shape? value = null) - { - return new ShapeConst(value ?? Value); - } - - public bool Equals(ShapeConst? other) => other is ShapeConst o && Value.Equals(o.Value); - - public override bool Equals(object? obj) - { - return Equals(obj as ShapeConst); - } -} - -/// -/// Constant of tensor. -/// -public sealed class DimensionConst : Const, IEquatable -{ - public DimensionConst(Dimension value) - : base(new TensorType(DataTypes.Int64, Shape.Scalar)) - { - Value = value; - } - - public Dimension Value { get; } - - /// - public override string ToString() - { - return Value.ToString(); - } - - /// - public override TExprResult Accept(ExprFunctor functor, TContext context) - => functor.VisitDimensionConst(this, context); - - public DimensionConst With(Dimension? value = null) - { - return new DimensionConst(value ?? Value); - } - - public bool Equals(DimensionConst? other) => other is DimensionConst o && Value.Equals(o.Value); - - public override bool Equals(object? obj) - { - return Equals(obj as DimensionConst); - } -} diff --git a/src/Nncase.Core/IR/TypePattern.cs b/src/Nncase.Core/IR/TypePattern.cs index 54bfe3fb79..e09aa7bd35 100644 --- a/src/Nncase.Core/IR/TypePattern.cs +++ b/src/Nncase.Core/IR/TypePattern.cs @@ -128,7 +128,7 @@ public static TypePattern HasShape(Shape target_shape) => HasShape( inshape => inshape.Rank == target_shape.Rank && inshape.Zip(target_shape).All( - (dim) => dim.Second.IsAny ? true : dim.Second == dim.First), + (dim) => dim.Second.IsAssignableFrom(dim.First)), $"Shape = {target_shape}"); /// diff --git a/src/Nncase.Core/IValue.cs b/src/Nncase.Core/IValue.cs index f58cd1272b..e63297cd9a 100644 --- a/src/Nncase.Core/IValue.cs +++ b/src/Nncase.Core/IValue.cs @@ -98,10 +98,6 @@ public static IValue FromConst(Const @const) return FromTensor(tc.Value); case TupleConst tpc: return tpc.Value; - case ShapeConst spc: - return new ShapeValue(spc.Value.ToArray()); - case DimensionConst dc: - return new DimensionValue(dc.Value); default: throw new ArgumentOutOfRangeException(nameof(@const)); } @@ -259,109 +255,6 @@ public override string ToString() } } -public sealed class DimensionValue : IValue, IEquatable -{ - private readonly Dimension _value; - - public DimensionValue(Dimension value) - { - _value = value; - } - - public IRType Type => new TensorType(DataTypes.Int64, Shape.Scalar); - - public int Count => 0; - - public Dimension Dimension => _value; - - public IValue this[int index] => throw new NotSupportedException("scalar can't index"); - - public Tensor AsTensor() => throw new NotImplementedException(); - - public Tensor[] AsTensors() => throw new NotImplementedException(); - - public bool Equals(DimensionValue? other) => EqualityComparer.Default.Equals(_value, other?._value); - - public override int GetHashCode() => EqualityComparer.Default.GetHashCode(_value); - - public override bool Equals(object? obj) - { - return Equals(obj as DimensionValue); - } - - public IEnumerator GetEnumerator() - { - yield break; - } - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); -} - -public sealed class ShapeValue : IValue, IEquatable -{ - private readonly Dimension[] _values; - - public ShapeValue(params Dimension[] values) - { - _values = values; - } - - public ShapeValue(IEnumerable values) - { - var dims = new List(); - foreach (var item in values) - { - if (item is DimensionValue dimValue) - { - dims.Add(dimValue.Dimension); - } - else if (item is ShapeValue shapeValue) - { - dims.AddRange(shapeValue._values); - } - else - { - throw new NotSupportedException("only support dimension/shape value for constructor"); - } - } - - _values = dims.ToArray(); - } - - public IRType Type => new TensorType(DataTypes.Int64, new[] { _values.Length }); - - public int Count => _values.Length; - - public Span Dimensions => _values; - - public IValue this[int index] => new DimensionValue(_values[index]); - - public Tensor AsTensor() => throw new NotImplementedException(); - - public Tensor[] AsTensors() => throw new NotImplementedException(); - - public bool Equals(ShapeValue? other) => StructuralComparisons.StructuralEqualityComparer.Equals(_values, other?._values); - - public override int GetHashCode() => StructuralComparisons.StructuralEqualityComparer.GetHashCode(_values); - - public override bool Equals(object? obj) - { - return Equals(obj as ShapeValue); - } - - public IEnumerator GetEnumerator() - { - foreach (var item in _values) - { - yield return new DimensionValue(item); - } - } - - IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - - public override string ToString() => $"[{string.Join(",", _values.AsValueEnumerable().Select(v => v.ToString()))}]"; -} - /// /// Tuple value. /// diff --git a/src/Nncase.Core/PatternMatch/Pattern.Conversions.cs b/src/Nncase.Core/PatternMatch/Pattern.Conversions.cs index 77b7fd5c85..ad67aed154 100644 --- a/src/Nncase.Core/PatternMatch/Pattern.Conversions.cs +++ b/src/Nncase.Core/PatternMatch/Pattern.Conversions.cs @@ -46,6 +46,8 @@ public abstract partial record Pattern public static implicit operator Pattern(int[] span) => Const.FromTensor(Tensor.From(span)); + public static implicit operator Pattern(long[] span) => Const.FromTensor(Tensor.From(span)); + public static implicit operator Pattern(float[] span) => Const.FromTensor(Tensor.From(span)); diff --git a/src/Nncase.Core/TIR/TIRUtilities.cs b/src/Nncase.Core/TIR/TIRUtilities.cs index 6dcad9ad16..9b50ec691c 100644 --- a/src/Nncase.Core/TIR/TIRUtilities.cs +++ b/src/Nncase.Core/TIR/TIRUtilities.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Text; using System.Threading.Tasks; +using Nncase.IR; namespace Nncase.TIR; @@ -20,7 +21,7 @@ public static class TIRUtilities /// public static IReadOnlyList<(IR.Expr Before, IR.Expr After)> ComputePaddings(IReadOnlyList bounds, IR.Shape shape) => bounds.Select((bound, i) => - ((IR.Expr)IR.F.Math.Max(-bound.Start, 0), (IR.Expr)IR.F.Math.Max(bound.Stop - shape[i].FixedValue, 0))).ToArray(); + ((IR.Expr)IR.F.Math.Max(-bound.Start, 0), (IR.Expr)IR.F.Math.Max((bound.Stop - shape[i]).ToExpr(), 0))).ToArray(); /// /// give two bounds compute paddings. diff --git a/src/Nncase.Core/Utilities/DistributedUtility.cs b/src/Nncase.Core/Utilities/DistributedUtility.cs index c2033018bc..d8a0e040eb 100644 --- a/src/Nncase.Core/Utilities/DistributedUtility.cs +++ b/src/Nncase.Core/Utilities/DistributedUtility.cs @@ -15,7 +15,7 @@ public static IReadOnlyList> GetLeafCandidateNDSBPs(TensorType tens for (int i = 0; i < placement.Rank; i++) { var ndsbp = new List(); - if (tensorType.Shape.All(x => x.IsUnknown || x.FixedValue != 0)) + if (tensorType.Shape.All(x => x.IsDynamic || x.FixedValue != 0)) { for (int axis = 0; axis < tensorType.Shape.Rank; axis++) { @@ -227,7 +227,7 @@ public static float GetDividedTensorEfficiency(DistributedType distributedType, return 1f; } - return Enumerable.Range(0, tiles.Count).Select(i => ((int)tiles[i].FixedValue).Ranges(0, (int)shape[i].FixedValue)).CartesianProduct().Select(rgs => + return Enumerable.Range(0, tiles.Rank).Select(i => ((int)tiles[i].FixedValue).Ranges(0, (int)shape[i].FixedValue)).CartesianProduct().Select(rgs => { var slice = rgs.ToArray(); var iscontiguous = TensorUtilities.IsContiguousSlice(shape.ToValueArray(), slice, out var contiguousStart); diff --git a/src/Nncase.Core/Utilities/StringUtility.cs b/src/Nncase.Core/Utilities/StringUtility.cs index e2f81ead2c..0205ec18b9 100644 --- a/src/Nncase.Core/Utilities/StringUtility.cs +++ b/src/Nncase.Core/Utilities/StringUtility.cs @@ -15,6 +15,96 @@ namespace Nncase.Utilities; public static class StringUtility { + public static string Join(char separator, ReadOnlySpan values) + { + var en = values.GetEnumerator(); + if (!en.MoveNext()) + { + return string.Empty; + } + + // We called MoveNext once, so this will be the first item + var currentValue = en.Current; + + // Call ToString before calling MoveNext again, since + // we want to stay consistent with the below loop + // Everything should be called in the order + // MoveNext-Current-ToString, unless further optimizations + // can be made, to avoid breaking changes + string? firstString = currentValue?.ToString(); + + // If there's only 1 item, simply call ToString on that + if (!en.MoveNext()) + { + // We have to handle the case of either currentValue + // or its ToString being null + return firstString ?? string.Empty; + } + + var result = new StringBuilder(); + + result.Append(firstString); + + do + { + currentValue = en.Current; + + result.Append(separator); + if (currentValue != null) + { + result.Append(currentValue.ToString()); + } + } + while (en.MoveNext()); + + return result.ToString(); + } + + public static string Join(ReadOnlySpan separator, ReadOnlySpan values) + { + var en = values.GetEnumerator(); + if (!en.MoveNext()) + { + return string.Empty; + } + + // We called MoveNext once, so this will be the first item + var currentValue = en.Current; + + // Call ToString before calling MoveNext again, since + // we want to stay consistent with the below loop + // Everything should be called in the order + // MoveNext-Current-ToString, unless further optimizations + // can be made, to avoid breaking changes + string? firstString = currentValue?.ToString(); + + // If there's only 1 item, simply call ToString on that + if (!en.MoveNext()) + { + // We have to handle the case of either currentValue + // or its ToString being null + return firstString ?? string.Empty; + } + + var result = new StringBuilder(); + + result.Append(firstString); + + do + { + currentValue = en.Current; + + result.Append(separator); + if (currentValue != null) + { + result.Append(currentValue.ToString()); + } + } + while (en.MoveNext()); + + return result.ToString(); + } + public static string Join(ReadOnlySpan separator, in SpanSelectEnumerable values) where TSelector : struct, IFunction { diff --git a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs index 2edecdc4b0..a068e87fba 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ILPrintVisitor.cs @@ -530,7 +530,39 @@ protected override string VisitGrid(IR.Affine.Grid expr) private string GetNextSSANumber() { - return $"%{_stackedSSANumbers[^1]++}"; + if (_names.TryGetValue(expr, out var name)) + { + return name; + } + + _scope.Push(); + + // 1. Sequential signature + _scope.Append($"Sequential"); + AppendCheckedType(expr.CheckedType, expr.Metadata.Range, " {", hasNewLine: true); + + // 2. For Body + using (_scope.IndentUp()) + { + foreach (var item in expr.Fields) + { + Visit(item); + } + } + + // 3. For closing + _scope.IndWriteLine("}"); + + // 4. extact whole il + _scope.IndWrite(_scope.Pop()); + return string.Empty; + } + + private string AllocateTempVar(Expr expr) + { + var name = $"%{_localId++}"; + _names.Add(expr, name); + return name; } private string VisitShape(Shape shape) => @@ -565,7 +597,7 @@ private string VisitDimensionExpr(Expr expr) private void AppendCheckedType(IRType? type, ValueRange? range, string end = "", bool hasNewLine = true) { - var rangeText = range is not null ? $" [{range.Value.Min}, {range.Value.Max}]" : string.Empty; + var rangeText = range is not null ? $"[{range.Value.Min}, {range.Value.Max}]" : string.Empty; if (type is not null) { if (hasNewLine) diff --git a/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs b/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs index b49c2b214b..721b748e5f 100644 --- a/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs +++ b/src/Nncase.Diagnostics/Diagnostics/ScriptPrintVisitor.cs @@ -454,7 +454,7 @@ protected override IPrintSymbol VisitVar(Var expr) return doc; } - doc = (ScriptSymobl)_scope.GetUniqueVarSymbol(expr); + doc = (ScriptSymobl)_scope.GetUniqueVarSymbol(expr, "%"); _exprMemo.Add(expr, doc); return doc; } @@ -703,6 +703,18 @@ protected override IPrintSymbol VisitNone(None expr) return doc; } + protected override IPrintSymbol VisitShape(Shape expr) + { + if (_exprMemo.TryGetValue(expr, out var doc)) + { + return doc; + } + + doc = new ScriptSymobl(new("Shape"), "Shape", false); + _exprMemo.Add(expr, doc); + return doc; + } + /// /// indent xxxxxx ( // type_info /// indent indent xxx diff --git a/src/Nncase.EGraph/Passes/EGraphExtractor.cs b/src/Nncase.EGraph/Passes/EGraphExtractor.cs index 42c283042a..31475be437 100644 --- a/src/Nncase.EGraph/Passes/EGraphExtractor.cs +++ b/src/Nncase.EGraph/Passes/EGraphExtractor.cs @@ -621,11 +621,8 @@ public Expr Visit(EClass root) case Marker mk: expr = mk.With(target: children[0], attribute: children[1], metadata: mk.Metadata); break; - case ShapeConst sc: - expr = sc; - break; - case DimensionConst dc: - expr = dc; + case Shape shape: + expr = shape.With(children); break; default: throw new NotSupportedException(enode.Expr.GetType().Name); diff --git a/src/Nncase.EGraph/Passes/EGraphPrinter.cs b/src/Nncase.EGraph/Passes/EGraphPrinter.cs index d574486379..0512b7965c 100644 --- a/src/Nncase.EGraph/Passes/EGraphPrinter.cs +++ b/src/Nncase.EGraph/Passes/EGraphPrinter.cs @@ -258,8 +258,6 @@ protected override string VisitConst(Const expr) { TensorConst tc => tc.Value.Shape.Size <= 8 ? tc.Value.GetArrayString(false) : string.Empty, TupleConst => string.Empty, - ShapeConst sc => VisitShape(sc.Value), - DimensionConst dc => VisitDimension(dc.Value), _ => throw new ArgumentOutOfRangeException(nameof(expr)), }; valueStr = valueStr != string.Empty ? " : " + valueStr : string.Empty; @@ -285,20 +283,20 @@ protected override string VisitOp(Op expr) protected override string VisitNone(None expr) => "None"; - private string VisitShape(Shape shape) => - shape.Kind switch - { - ShapeKind.Invalid => "Invalid", - ShapeKind.Unranked => "Unranked", - _ => $"[{string.Join(',', shape.Select(VisitDimension))}]", - }; + protected override string VisitShape(Shape shape) => + shape.Kind switch + { + ShapeKind.Invalid => "Invalid", + ShapeKind.Unranked => "Unranked", + _ => $"[{string.Join(',', shape.Select(VisitDimension))}]", + }; private string VisitDimension(Dimension dimension) => dimension.Kind switch { - DimensionKind.Any => "any", + DimensionKind.Unknown => "?", DimensionKind.Fixed => dimension.FixedValue.ToString(), - DimensionKind.Unknown => dimension.Value is Var var ? $"%{var.Name}" : "?", + DimensionKind.Dynamic => dimension.Value is Var var ? $"%{var.Name}" : "...", _ => throw new NotSupportedException(dimension.Kind.ToString()), }; } diff --git a/src/Nncase.Evaluator/Buffers/StrideOf.cs b/src/Nncase.Evaluator/Buffers/StrideOf.cs index 1b1702a3a5..cc9f1fae85 100644 --- a/src/Nncase.Evaluator/Buffers/StrideOf.cs +++ b/src/Nncase.Evaluator/Buffers/StrideOf.cs @@ -15,5 +15,5 @@ namespace Nncase.Evaluator.Buffers; [TypeInferGenerator] public partial class StrideOfEvaluator : ITypeInferencer { - private IRType Visit(TensorType input) => new TensorType(DataTypes.Int32, new[] { input.Shape.Count }); + private IRType Visit(TensorType input) => new TensorType(DataTypes.Int32, new[] { input.Shape.Rank }); } diff --git a/src/Nncase.Evaluator/Math/Binary.cs b/src/Nncase.Evaluator/Math/Binary.cs index 4561051115..5d32607555 100755 --- a/src/Nncase.Evaluator/Math/Binary.cs +++ b/src/Nncase.Evaluator/Math/Binary.cs @@ -80,61 +80,45 @@ public static IRType CheckSBP(BinaryOp op, TensorType tensorType, DistributedTyp /// public IValue Visit(IEvaluateContext context, Binary binary) { - var lhsValue = context.GetArgumentValue(binary, Binary.Lhs); - var rhsValue = context.GetArgumentValue(binary, Binary.Rhs); - switch (lhsValue, rhsValue) + var lhs = context.GetArgumentValueAsTensor(binary, Binary.Lhs); + var rhs = context.GetArgumentValueAsTensor(binary, Binary.Rhs); + if (lhs.Shape.IsScalar && rhs.Shape.IsScalar) { - case (TensorValue lhsTV, TensorValue rhsTV): - var lhs = lhsTV.AsTensor(); - var rhs = rhsTV.AsTensor(); - if (lhs.Shape.IsScalar && rhs.Shape.IsScalar) - { - if (lhs.ElementType == DataTypes.Int32 && rhs.ElementType == DataTypes.Int32) - { - return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); - } - else if (lhs.ElementType == DataTypes.Int64 && rhs.ElementType == DataTypes.Int64) - { - return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); - } - else if (lhs.ElementType == DataTypes.Float32 && rhs.ElementType == DataTypes.Float32) - { - return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); - } - else if (lhs.ElementType == DataTypes.Boolean && rhs.ElementType == DataTypes.Boolean) - { - return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); - } - else if (lhs.ElementType == DataTypes.UInt32 && rhs.ElementType == DataTypes.UInt32) - { - return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); - } - else if (lhs.ElementType is PointerType && (rhs.ElementType == DataTypes.UInt32 || rhs.ElementType == DataTypes.UInt64)) - { - return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); - } - else if ((lhs.ElementType == DataTypes.UInt32 || lhs.ElementType == DataTypes.UInt64) && rhs.ElementType is PointerType) - { - return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); - } - else - { - return Ort_compute(binary, lhs, rhs); - } - } - + if (lhs.ElementType == DataTypes.Int32 && rhs.ElementType == DataTypes.Int32) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else if (lhs.ElementType == DataTypes.Int64 && rhs.ElementType == DataTypes.Int64) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else if (lhs.ElementType == DataTypes.Float32 && rhs.ElementType == DataTypes.Float32) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else if (lhs.ElementType == DataTypes.Boolean && rhs.ElementType == DataTypes.Boolean) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else if (lhs.ElementType == DataTypes.UInt32 && rhs.ElementType == DataTypes.UInt32) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else if (lhs.ElementType is PointerType && (rhs.ElementType == DataTypes.UInt32 || rhs.ElementType == DataTypes.UInt64)) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else if ((lhs.ElementType == DataTypes.UInt32 || lhs.ElementType == DataTypes.UInt64) && rhs.ElementType is PointerType) + { + return Value.FromTensor(Tensor.FromScalar(Compute(binary.BinaryOp, lhs.ToScalar(), rhs.ToScalar()))); + } + else + { return Ort_compute(binary, lhs, rhs); - case (DimensionValue { Dimension: { IsFixed: true } } lhsDV, TensorValue { Type: TensorType { Shape: { IsScalar: true } } } rhsTV): - return Value.FromTensor(Compute(binary.BinaryOp, lhsDV.Dimension.FixedValue, rhsTV.AsTensor().ToScalar())); - case (TensorValue { Type: TensorType { Shape: { IsScalar: true } } } lhsTV, DimensionValue { Dimension: { IsFixed: true } } rhsDV): - return Value.FromTensor(Compute(binary.BinaryOp, lhsTV.AsTensor().ToScalar(), rhsDV.Dimension.FixedValue)); - case (DimensionValue { Dimension: { IsFixed: true } } lhsDV, DimensionValue { Dimension: { IsFixed: true } } rhsDV): - return Value.FromTensor(Compute(binary.BinaryOp, lhsDV.Dimension.FixedValue, rhsDV.Dimension.FixedValue)); - default: - break; + } } - throw new NotSupportedException($"binary notsupport {lhsValue} {rhsValue}"); + return Ort_compute(binary, lhs, rhs); } /// diff --git a/src/Nncase.Evaluator/Math/Compare.cs b/src/Nncase.Evaluator/Math/Compare.cs index cc8d776daa..161e16ca8a 100644 --- a/src/Nncase.Evaluator/Math/Compare.cs +++ b/src/Nncase.Evaluator/Math/Compare.cs @@ -2,7 +2,6 @@ // Licensed under the Apache license. See LICENSE file in the project root for full license information. using System; -using System.Linq; using Nncase.CostModel; using Nncase.Diagnostics; using Nncase.IR; @@ -69,31 +68,25 @@ public static IRType CheckSBP(TensorType tensorType, DistributedType a, Distribu /// public IValue Visit(IEvaluateContext context, Compare target) { - var lhsValue = context.GetArgumentValue(target, Compare.Lhs); - var rhsValue = context.GetArgumentValue(target, Compare.Rhs); - switch (lhsValue, rhsValue) + var lhs = context.GetArgumentValueAsTensor(target, Compare.Lhs); + var rhs = context.GetArgumentValueAsTensor(target, Compare.Rhs); + if (lhs.Shape.IsScalar && rhs.Shape.IsScalar && lhs.ElementType == DataTypes.Int32 && rhs.ElementType == DataTypes.Int32) { - case (TensorValue lhsTV, TensorValue rhsTV): - var lhs = lhsTV.AsTensor(); - var rhs = rhsTV.AsTensor(); - if (lhs.Shape.IsScalar && rhs.Shape.IsScalar && lhs.ElementType == DataTypes.Int32 && rhs.ElementType == DataTypes.Int32) - { - return Value.FromTensor(Tensor.FromScalar(Compute(target.CompareOp, lhs.ToScalar(), rhs.ToScalar()))); - } - else - { - return Compute(target.CompareOp, lhs.ToOrtTensor(), rhs.ToOrtTensor()); - } - - case (ShapeValue lhsTV, TensorValue rhsTV): - return Value.FromTensor(Tensor.FromArray(lhsTV.Dimensions.ToArray().Zip(rhsTV.AsTensor().ToArray()).Select(p => Compute(target.CompareOp, p.First, p.Second)).ToArray())); - case (TensorValue lhsTV, ShapeValue rhsTV): - return Value.FromTensor(Tensor.FromArray(lhsTV.AsTensor().ToArray().Zip(rhsTV.Dimensions.ToArray()).Select(p => Compute(target.CompareOp, p.First, p.Second)).ToArray())); - default: - break; + return Value.FromTensor(Tensor.FromScalar(Compute(target.CompareOp, lhs.ToScalar(), rhs.ToScalar()))); } - throw new NotSupportedException(); + var a = context.GetOrtArgumentValue(target, Compare.Lhs); + var b = context.GetOrtArgumentValue(target, Compare.Rhs); + return target.CompareOp switch + { + CompareOp.Equal => OrtKI.Equal(a, b).ToValue(), + CompareOp.LowerOrEqual => OrtKI.LessOrEqual(a, b).ToValue(), + CompareOp.GreaterOrEqual => OrtKI.GreaterOrEqual(a, b).ToValue(), + CompareOp.GreaterThan => OrtKI.Greater(a, b).ToValue(), + CompareOp.LowerThan => OrtKI.Less(a, b).ToValue(), + CompareOp.NotEqual => OrtKI.Not(OrtKI.Equal(a, b)).ToValue(), + _ => throw new ArgumentOutOfRangeException(target.CompareOp.ToString()), + }; } /// @@ -163,41 +156,7 @@ public Expr Visit(IShapeEvaluateContext context, Compare target) return ShapeExprUtility.BroadcastShape(lhs, rhs); } - private bool Compute(CompareOp op, Dimension a, long b) - { - return (a, b) switch - { - (Dimension { IsFixed: true } da, _) => Compute(op, da.FixedValue, b), - (Dimension { IsFixed: false } da, _) => Compute(op, da.Value, b), - }; - } - - private bool Compute(CompareOp op, long a, Dimension b) - { - return (a, b) switch - { - (_, Dimension { IsFixed: true } db) => Compute(op, a, db.FixedValue), - (_, Dimension { IsFixed: false } db) => Compute(op, a, db.Value), - }; - } - - private IValue Compute(CompareOp op, OrtKISharp.Tensor a, OrtKISharp.Tensor b) - { - return op switch - { - CompareOp.Equal => OrtKI.Equal(a, b).ToValue(), - CompareOp.LowerOrEqual => OrtKI.LessOrEqual(a, b).ToValue(), - CompareOp.GreaterOrEqual => OrtKI.GreaterOrEqual(a, b).ToValue(), - CompareOp.GreaterThan => OrtKI.Greater(a, b).ToValue(), - CompareOp.LowerThan => OrtKI.Less(a, b).ToValue(), - CompareOp.NotEqual => OrtKI.Not(OrtKI.Equal(a, b)).ToValue(), - _ => throw new ArgumentOutOfRangeException(op.ToString()), - }; - } - - private bool Compute(CompareOp op, T a, T b) - where T : System.Numerics.IEqualityOperators, System.Numerics.IComparisonOperators - => op switch + private bool Compute(CompareOp op, int a, int b) => op switch { CompareOp.Equal => a == b, CompareOp.LowerOrEqual => a <= b, @@ -208,20 +167,6 @@ private bool Compute(CompareOp op, T a, T b) _ => throw new ArgumentOutOfRangeException(nameof(op)), }; - private bool Compute(CompareOp op, Expr a, long b) - => (op, a, b) switch - { - (CompareOp.Equal, Expr, -1) => false, - _ => throw new ArgumentOutOfRangeException(nameof(op)), - }; - - private bool Compute(CompareOp op, long a, Expr b) - => (op, a, b) switch - { - (CompareOp.Equal, -1, Expr) => false, - _ => throw new ArgumentOutOfRangeException(nameof(op)), - }; - private IRType Visit(TensorType lhs, TensorType rhs) { var broadcastType = TypeInference.BroadcastType(lhs, rhs); diff --git a/src/Nncase.Evaluator/Math/MatMul.cs b/src/Nncase.Evaluator/Math/MatMul.cs index 817714bde3..5fe4e42815 100644 --- a/src/Nncase.Evaluator/Math/MatMul.cs +++ b/src/Nncase.Evaluator/Math/MatMul.cs @@ -214,7 +214,7 @@ public static IRType VisitTensorType(TensorType lhs, TensorType rhs, bool packin var rhsShape = lhs.Shape.Rank <= rhs.Shape.Rank ? rhs.Shape.ToArray() : Enumerable.Repeat((Dimension)1, lhs.Shape.Rank - rhs.Shape.Rank).Concat(rhs.Shape).ToArray(); var bigShape = Enumerable.Zip(lhsShape, rhsShape).SkipLast(2).Select(t => - t.First.IsUnknown || t.Second.IsUnknown + t.First.IsDynamic || t.Second.IsDynamic ? (Dimension)IR.F.Math.Max(t.First.Value, t.Second.Value) : System.Math.Max(t.First.FixedValue, t.Second.FixedValue)).ToArray(); diff --git a/src/Nncase.Evaluator/NN/Conv2DTranspose.cs b/src/Nncase.Evaluator/NN/Conv2DTranspose.cs index b17a785866..1bea8f1a62 100644 --- a/src/Nncase.Evaluator/NN/Conv2DTranspose.cs +++ b/src/Nncase.Evaluator/NN/Conv2DTranspose.cs @@ -120,7 +120,7 @@ public IRType Visit(ITypeInferenceContext context, Conv2DTranspose target) var input = context.CheckArgumentType(target, Conv2DTranspose.Input); if (context.GetArgument(target, Conv2DTranspose.OutputShape) is TensorConst outShapeValue) { - return new TensorType(input.DType, new Shape(outShapeValue.Value.Cast())); + return new TensorType(input.DType, new Shape(outShapeValue.Value.ToArray())); } else { diff --git a/src/Nncase.Evaluator/Random/Normal.cs b/src/Nncase.Evaluator/Random/Normal.cs index 89381a8c1f..4e2faeee04 100644 --- a/src/Nncase.Evaluator/Random/Normal.cs +++ b/src/Nncase.Evaluator/Random/Normal.cs @@ -38,7 +38,7 @@ public IRType Visit(ITypeInferenceContext context, Normal target) { if (context.GetArgument(target, Normal.Shape) is TensorConst shapeValue) { - return new TensorType(target.Type, new Shape(shapeValue.Value.Cast())); + return new TensorType(target.Type, new Shape(shapeValue.Value.ToArray())); } else { diff --git a/src/Nncase.Evaluator/Random/Uniform.cs b/src/Nncase.Evaluator/Random/Uniform.cs index 17908d4884..95fd73ae52 100644 --- a/src/Nncase.Evaluator/Random/Uniform.cs +++ b/src/Nncase.Evaluator/Random/Uniform.cs @@ -32,7 +32,7 @@ public IRType Visit(ITypeInferenceContext context, Uniform target) { if (context.GetArgument(target, Uniform.Shape) is TensorConst shapeValue) { - return new TensorType(target.Type, new Shape(shapeValue.Value.Cast())); + return new TensorType(target.Type, new Shape(shapeValue.Value.ToArray())); } else { diff --git a/src/Nncase.Evaluator/Tensors/Concat.cs b/src/Nncase.Evaluator/Tensors/Concat.cs index 3c74ad6755..74c5276358 100644 --- a/src/Nncase.Evaluator/Tensors/Concat.cs +++ b/src/Nncase.Evaluator/Tensors/Concat.cs @@ -23,45 +23,11 @@ public class ConcatEvaluator : IEvaluator, ITypeInferencer, ICos IShapeEvaluator, IMetricEvaluator { /// - public IValue Visit(IEvaluateContext context, Concat target) + public IValue Visit(IEvaluateContext context, Concat cat) { - var inputValue = context.GetArgumentValue(target, Concat.Input); - var axis = target.Axis; - switch (inputValue) - { - case TupleValue tpv: - if (tpv.All(v => v is TensorValue)) - { - var inputs = tpv.AsTensors(); - return OrtKI.Concat(inputs.Select(t => t.ToOrtTensor()).ToArray(), axis).ToValue(); - } - else if (tpv.Any(v => v is ShapeValue)) - { - var dims = new List(); - foreach (var fv in tpv) - { - switch (fv) - { - case TensorValue ftv: - dims.Add(new Dimension(ftv.AsTensor().Cast()[axis])); - break; - case ShapeValue fsv: - dims.Add(fsv.Dimensions[axis]); - break; - default: - throw new ArgumentOutOfRangeException(nameof(target), "ShapeValue's field not support"); - } - } - - return new ShapeValue(dims.ToArray()); - } - - break; - default: - break; - } - - throw new ArgumentOutOfRangeException(nameof(target)); + var inputs = context.GetArgumentValueAsTensors(cat, Concat.Input); + var axis = cat.Axis; + return OrtKI.Concat(inputs.Select(t => t.ToOrtTensor()).ToArray(), axis).ToValue(); } /// @@ -258,6 +224,6 @@ private Dimension AxisDim(TupleType inputs, int axisValue) { return inputs.Fields.Aggregate( (Dimension)0, - (prod, next) => prod + (next switch { TensorType t => t, DistributedType d => d.TensorType, _ => throw new NotSupportedException() }).Shape[axisValue].Value); + (prod, next) => prod + (next switch { TensorType t => t, DistributedType d => d.TensorType, _ => throw new NotSupportedException() }).Shape[axisValue]); } } diff --git a/src/Nncase.Evaluator/Tensors/Flatten.cs b/src/Nncase.Evaluator/Tensors/Flatten.cs index c3eb97a623..78804d3e6b 100644 --- a/src/Nncase.Evaluator/Tensors/Flatten.cs +++ b/src/Nncase.Evaluator/Tensors/Flatten.cs @@ -46,7 +46,7 @@ private IRType Visit(ITypeInferenceContext context, Flatten target, TensorType i { var axisValue = Util.PositiveIndex(axisV.Value.ToScalar(), input); var first = input.Shape.Take(axisValue).Aggregate((Dimension)1, (x, y) => x * y); - var second = input.Shape.Take(axisValue..input.Shape.Count).Aggregate((Dimension)1, (x, y) => x * y); + var second = input.Shape.Take(axisValue..input.Shape.Rank).Aggregate((Dimension)1, (x, y) => x * y); return input with { Shape = new[] { first, second } }; } diff --git a/src/Nncase.Evaluator/Tensors/Gather.cs b/src/Nncase.Evaluator/Tensors/Gather.cs index 2d6e9e0965..3aebcbd9b4 100644 --- a/src/Nncase.Evaluator/Tensors/Gather.cs +++ b/src/Nncase.Evaluator/Tensors/Gather.cs @@ -22,33 +22,10 @@ public class GatherEvaluator : IEvaluator, ITypeInferencer, ICos /// public IValue Visit(IEvaluateContext context, Gather gather) { - var inputValue = context.GetArgumentValue(gather, Gather.Input); - var indexValue = context.GetArgumentValue(gather, Gather.Index); + var input = context.GetOrtArgumentValue(gather, Gather.Input); var axis = gather.Axis; - switch (inputValue, indexValue) - { - case (_, TensorValue indexTValue): - if (inputValue is TensorValue inputTValue) - { - return OrtKI.Gather(inputTValue.AsTensor().ToOrtTensor(), indexTValue.AsTensor().ToOrtTensor(), axis).ToValue(); - } - else if (inputValue is ShapeValue inputSValue && axis == 0) - { - var indexTensor = indexTValue.AsTensor(); - if (!indexTensor.Shape.IsScalar) - { - throw new NotSupportedException("Gather ShapeConst the index must be scalar!"); - } - - return inputSValue[indexTensor.ToScalar()]; - } - - break; - default: - break; - } - - throw new NotSupportedException(); + var index = context.GetOrtArgumentValue(gather, Gather.Index); + return OrtKI.Gather(input, index, axis).ToValue(); } /// diff --git a/src/Nncase.Evaluator/Tensors/Reshape.cs b/src/Nncase.Evaluator/Tensors/Reshape.cs index af916539c6..c139958a66 100644 --- a/src/Nncase.Evaluator/Tensors/Reshape.cs +++ b/src/Nncase.Evaluator/Tensors/Reshape.cs @@ -94,7 +94,7 @@ private IRType Visit(ITypeInferenceContext context, Reshape target, TensorType i } var rank = (int)shapeType.Shape[0].FixedValue; - var shapeDims = new Shape(Enumerable.Range(0, rank).Select(i => (Dimension)shape[i]).ToArray()); + var shapeDims = new Shape(Enumerable.Range(0, rank).Select(i => shape[i]).ToArray()); var outputShape = new Dimension[rank]; // todo use egraph simplify. diff --git a/src/Nncase.Evaluator/Tensors/Slice.cs b/src/Nncase.Evaluator/Tensors/Slice.cs index 93b58c7eac..dd4c05c142 100644 --- a/src/Nncase.Evaluator/Tensors/Slice.cs +++ b/src/Nncase.Evaluator/Tensors/Slice.cs @@ -30,55 +30,13 @@ public class SliceEvaluator : IEvaluator, ITypeInferencer, ICostEv /// public IValue Visit(IEvaluateContext context, Slice sl) { - var inputValue = context.GetArgumentValue(sl, Slice.Input); - var beginsValue = context.GetArgumentValue(sl, Slice.Begins); - var endsValue = context.GetArgumentValue(sl, Slice.Ends); - var axesValue = context.GetArgumentValue(sl, Slice.Axes); - var stridesValue = context.GetArgumentValue(sl, Slice.Strides); - switch (inputValue, beginsValue, endsValue, axesValue, stridesValue) - { - case (_, TensorValue beginsTValue, TensorValue endsTValue, TensorValue axesTValue, TensorValue stridesTValue): - var beginsTensor = beginsTValue.AsTensor(); - var endsTensor = endsTValue.AsTensor(); - var axesTensor = axesTValue.AsTensor(); - var stridesTensor = stridesTValue.AsTensor(); - if (inputValue is ShapeValue inputSValue && beginsTensor.Shape.Rank == 1 && endsTensor.Shape.Rank == 1 && axesTensor.Shape.Rank == 1 && stridesTensor.Shape.Rank == 1) - { - // var input = inputShapeValue.AsTensor().Cast().ToOrtTensor(); - var begins = beginsTensor.ToScalar(); - var ends = endsTensor.ToScalar(); - var axes = axesTensor.ToScalar(); - if (axes != 0) - { - throw new NotSupportedException("slice ShapeConst Axes != 0"); - } - - var strides = stridesTensor.ToScalar(); - var sliced = new List(); - for (long i = begins; i < ends; i += strides) - { - sliced.Add(inputSValue[checked((int)i)]); - } - - return new ShapeValue(sliced); - } - else if (inputValue is TensorValue inputTValue) - { - var input = inputTValue.AsTensor().ToOrtTensor(); - var begins = beginsTensor.Cast().ToOrtTensor(); - var ends = endsTensor.Cast().ToOrtTensor(); - var axes = axesTensor.Cast().ToOrtTensor(); - var strides = stridesTensor.Cast().ToOrtTensor(); - var sliced = OrtKI.Slice(input, begins, ends, axes, strides); - return Value.FromTensor(context.CurrentCall.CheckedType is AnyType ? sliced.ToTensor() : sliced.ToTensor(context.CurrentCall.CheckedTensorType)); - } - - break; - default: - break; - } - - throw new NotSupportedException("input value is neither shapevalue or tensorvalue"); + var input = context.GetOrtArgumentValue(sl, Slice.Input); + var begins = context.GetInt64OrtTensorArgumentValue(sl, Slice.Begins); + var ends = context.GetInt64OrtTensorArgumentValue(sl, Slice.Ends); + var axes = context.GetInt64OrtTensorArgumentValue(sl, Slice.Axes); + var strides = context.GetInt64OrtTensorArgumentValue(sl, Slice.Strides); + var sliced = OrtKI.Slice(input, begins, ends, axes, strides); + return Value.FromTensor(context.CurrentCall.CheckedType is AnyType ? sliced.ToTensor() : sliced.ToTensor(context.CurrentCall.CheckedTensorType)); } /// @@ -188,7 +146,7 @@ private static Dimension TranslateBeginEnd(Dimension x, Dimension dim, long lowe else { return ShapeExprUtility.If( - x.Value < 0, + x.Value < 0L, (x, dim) => dim + x, (x, dim) => Clamp(x, lowerBound, dim + upperBoundBias), x.Value, @@ -241,7 +199,7 @@ private IRType Visit(ITypeInferenceContext context, Slice target, TensorType inp // while for negative stepping it is clamped to [-1, dims[axes[i]]-1]. var end = TranslateBeginEnd(ends[i], inDim, -1, -1); - return Dimension.CeilDiv(Dimension.Abs(end - begin), Dimension.Abs(stride)); + return Dimension.CeilDiv(end - begin, Dimension.Abs(stride)); } else { @@ -250,7 +208,7 @@ private IRType Visit(ITypeInferenceContext context, Slice target, TensorType inp // end[i] is clamped into the range [0, dims[axes[i]]] var end = TranslateBeginEnd(ends[i], inDim, 0, 0); - return Dimension.CeilDiv(Dimension.Abs(end - begin), Dimension.Abs(stride)); + return Dimension.CeilDiv(end - begin, Dimension.Abs(stride)); } }); return input with { Shape = outShape }; diff --git a/src/Nncase.Evaluator/Tensors/Split.cs b/src/Nncase.Evaluator/Tensors/Split.cs index 60179cd2e4..e576b7e714 100644 --- a/src/Nncase.Evaluator/Tensors/Split.cs +++ b/src/Nncase.Evaluator/Tensors/Split.cs @@ -111,11 +111,11 @@ private IRType Visit(ITypeInferenceContext context, Split target, TensorType inp if (context.GetArgument(target, Split.Axis) is TensorConst axisCon) { var axisV = Util.PositiveIndex(axisCon.Value.ToScalar(), input.Shape.Rank); - splitedShape[axisV] = Dimension.Unknown(); + splitedShape[axisV] = Dimension.Unknown; } else { - splitedShape = splitedShape.Select(s => Dimension.Unknown()).ToArray(); + splitedShape = Enumerable.Repeat(Dimension.Unknown, splitedShape.Rank).ToArray(); } // return new TupleType(new IRType[] { input with { Shape = splitedShape } }, true); diff --git a/src/Nncase.Evaluator/Tensors/Tile.cs b/src/Nncase.Evaluator/Tensors/Tile.cs index a12469331f..4091ff279d 100644 --- a/src/Nncase.Evaluator/Tensors/Tile.cs +++ b/src/Nncase.Evaluator/Tensors/Tile.cs @@ -71,7 +71,7 @@ private IRType Visit(ITypeInferenceContext context, Tile target, TensorType inpu } else { - var shape = input.Shape.Select((p, i) => p * repeats[i]); + var shape = input.Shape.Select((p, i) => p * (Dimension)repeats[i]); return input with { Shape = new Shape(shape) }; } } diff --git a/src/Nncase.Evaluator/Tensors/UnSqueeze.cs b/src/Nncase.Evaluator/Tensors/UnSqueeze.cs index c87c93d0e3..76e1c7ad5b 100644 --- a/src/Nncase.Evaluator/Tensors/UnSqueeze.cs +++ b/src/Nncase.Evaluator/Tensors/UnSqueeze.cs @@ -19,35 +19,9 @@ public class UnsqueezeEvaluator : IEvaluator, ITypeInferencer public IValue Visit(IEvaluateContext context, Unsqueeze unSqueeze) { - var inputValue = context.GetArgumentValue(unSqueeze, Unsqueeze.Input); - var axesValue = context.GetArgumentValue(unSqueeze, Unsqueeze.Dim); - - switch (inputValue, axesValue) - { - case (_, TensorValue axesTValue): - var axesTensor = axesTValue.AsTensor(); - if (inputValue is TensorValue inputTValue) - { - var input = inputTValue.AsTensor().ToOrtTensor(); - var axes = axesTensor.Cast().ToOrtTensor(); - return Value.FromTensor(OrtKI.Unsqueeze(input, axes).ToTensor(context.CurrentCall.CheckedTensorType)); - } - else if (inputValue is DimensionValue inputDValue) - { - if (axesTensor.Shape.Rank > 1 || axesTensor.ToScalar() != 0) - { - throw new NotSupportedException("only support scalar dim when input is DimensionValue!"); - } - - return new ShapeValue(new[] { inputDValue }); - } - - break; - default: - break; - } - - throw new NotSupportedException(); + var input = context.GetOrtArgumentValue(unSqueeze, Unsqueeze.Input); + var axes = context.GetInt64OrtTensorArgumentValue(unSqueeze, Unsqueeze.Dim); + return Value.FromTensor(OrtKI.Unsqueeze(input, axes).ToTensor(context.CurrentCall.CheckedTensorType)); } /// diff --git a/src/Nncase.Evaluator/Tensors/Where.cs b/src/Nncase.Evaluator/Tensors/Where.cs index 7258d271e4..26223b18db 100644 --- a/src/Nncase.Evaluator/Tensors/Where.cs +++ b/src/Nncase.Evaluator/Tensors/Where.cs @@ -23,41 +23,24 @@ public class WhereEvaluator : IEvaluator, ITypeInferencer, ICostEv /// public IValue Visit(IEvaluateContext context, Where where) { - var condValue = context.GetArgumentValue(where, Where.Cond); - var xValue = context.GetArgumentValue(where, Where.X); - var yValue = context.GetArgumentValue(where, Where.Y); - switch (condValue, xValue, yValue) + var xt = context.GetArgumentValueAsTensor(where, Where.X); + var yt = context.GetArgumentValueAsTensor(where, Where.Y); + if (where.IsTfWhere) { - case (TensorValue condTV, TensorValue xTV, _): - if (yValue is TensorValue yTV) - { - if (where.IsTfWhere) - { - var condTensor = condTV.AsTensor().Cast(); - if (condTensor.Rank > 1) - { - throw new NotImplementedException(); - } - - var result = condTensor.Select((b, i) => (b, i)).Where(t => t.b).Select(t => (long)t.i).ToArray(); - return Value.FromTensor(Tensor.From(result, new Shape(result.Length, condTensor.Rank))); - } - else - { - return OrtKI.Where(condTV.AsTensor().ToOrtTensor(), xTV.AsTensor().ToOrtTensor(), yTV.AsTensor().ToOrtTensor()).ToValue(); - } - } - else if (yValue is ShapeValue ySV) - { - return new ShapeValue(condTV.AsTensor().Cast().Zip(xValue.AsTensor().Cast().Zip(ySV.Dimensions.ToArray())).Select(tp => tp.First ? new Dimension(tp.Second.First) : tp.Second.Second).ToArray()); - } - - break; - default: - break; + var condTensor = context.GetArgumentValueAsTensor(where, Where.Cond); + if (condTensor.Rank > 1) + { + throw new NotImplementedException(); + } + + var result = condTensor.Select((b, i) => (b, i)).Where(t => t.b).Select(t => (long)t.i).ToArray(); + return Value.FromTensor(Tensor.From(result, new Shape(result.Length, condTensor.Rank))); } - throw new NotSupportedException(); + var cond = context.GetOrtArgumentValue(where, Where.Cond); + var x = context.GetOrtArgumentValue(where, Where.X); + var y = context.GetOrtArgumentValue(where, Where.Y); + return OrtKI.Where(cond, x, y).ToValue(); } /// @@ -79,7 +62,7 @@ public IRType Visit(TensorType cond, TensorType x, TensorType y, Where target) { if (target.IsTfWhere) { - return new TensorType(DataTypes.Int64, new Shape(Dimension.Unknown(), cond.Shape.Rank)); + return new TensorType(DataTypes.Int64, Shape.Unknown(cond.Shape.Rank)); } return TypeInference.BroadcastType(x.DType, cond, x, y); diff --git a/src/Nncase.Evaluator/TypeInference.cs b/src/Nncase.Evaluator/TypeInference.cs index 7cd27a6e2a..f2f935853e 100644 --- a/src/Nncase.Evaluator/TypeInference.cs +++ b/src/Nncase.Evaluator/TypeInference.cs @@ -150,7 +150,7 @@ public static IRType BroadcastType(DataType dataType, params TensorType[] inputs inputDims[i] = inDim; } - var non1Dims = inputDims.Where(x => x.IsUnknown || x.FixedValue != 1).ToHashSet(); + var non1Dims = inputDims.Where(x => x.IsDynamic || x.IsUnknown || x.FixedValue != 1).ToHashSet(); if (non1Dims.Count == 0) { outputShape[dimIndex] = 1; @@ -420,7 +420,7 @@ public static IRType TransposeType(TensorType input, Expr perm) } var permt = permValue.Value.ToArray(); - if (input.Shape.Count != permt.Length) + if (input.Shape.Rank != permt.Length) { return new InvalidType("Transpose shoud perm.size == inShape.size"); } @@ -505,6 +505,11 @@ IRType CommonTypeImpl(TensorType a, TensorType b) return new InvalidType($"Inputs DType of if should be same, then: {a.DType}, else: {b.DType}"); } + if (a.Shape.IsUnranked || b.Shape.IsUnranked || a.Shape.Rank != b.Shape.Rank) + { + return new TensorType(a.DType, Shape.Unranked); + } + return new TensorType(a.DType, Shape.Unknown(a.Shape.Rank)); } diff --git a/src/Nncase.Evaluator/TypeInferenceVisitor.cs b/src/Nncase.Evaluator/TypeInferenceVisitor.cs index 745fe06522..3f1131f665 100644 --- a/src/Nncase.Evaluator/TypeInferenceVisitor.cs +++ b/src/Nncase.Evaluator/TypeInferenceVisitor.cs @@ -238,6 +238,11 @@ protected override IRType VisitLeafPrimFunctionWrapper(PrimFunctionWrapper expr) return type; } + protected override IRType VisitLeafShape(Shape expr) + { + return NoneType.Default; + } + protected override IRType VisitLeafGrid(Grid expr) { VerifySubField(expr, expr.DomainParameter); diff --git a/src/Nncase.Graph/Transform/DataFlowRewriter.cs b/src/Nncase.Graph/Transform/DataFlowRewriter.cs index f367c594ca..ee101ca986 100644 --- a/src/Nncase.Graph/Transform/DataFlowRewriter.cs +++ b/src/Nncase.Graph/Transform/DataFlowRewriter.cs @@ -26,6 +26,7 @@ internal sealed class DataFlowRewriter : ExprRewriter private readonly HashSet _dontInheritExprs = new HashSet(ReferenceEqualityComparer.Instance); public DataFlowRewriter(IRewriteRule rule, RunPassContext options) + : base(visitAttributes: true) { _rule = rule; _options = options; diff --git a/src/Nncase.Importer/Onnx/DataGatter.cs b/src/Nncase.Importer/Onnx/DataGatter.cs index a4424850b8..a85c6b0912 100644 --- a/src/Nncase.Importer/Onnx/DataGatter.cs +++ b/src/Nncase.Importer/Onnx/DataGatter.cs @@ -31,12 +31,7 @@ public sealed partial class OnnxImporter { TensorProto.Types.DataType.Uint8, DataTypes.UInt8 }, }; - public Shape GetShape(ValueInfoProto v) - { - var shape = v.Type.TensorType.Shape.Dim; - var dimArr = GetDimArray(shape, d => d, d => _dynVarMap[d.DimParam], d => (Dimension)d.DimValue); - return new Shape(dimArr); - } + public Shape GetShape(ValueInfoProto v) => new Shape(GetOriginShape(v)); public Expr[] GetOriginShape(ValueInfoProto v) { diff --git a/src/Nncase.Importer/Onnx/Slice.cs b/src/Nncase.Importer/Onnx/Slice.cs index e3b8a4a858..66788cd66c 100644 --- a/src/Nncase.Importer/Onnx/Slice.cs +++ b/src/Nncase.Importer/Onnx/Slice.cs @@ -44,7 +44,7 @@ private Expr SliceV10(in NodeProto op) // steps.size should eq starts.size starts.InferenceType(); var axes = GetOptionInputExpr(op, 3).Or(ComputeDefaultAxes(input)); - var steps = GetOptionInputExpr(op, 4).Or(Expand(1, starts.CheckedShape)); + var steps = GetOptionInputExpr(op, 4).Or(Expand(1, starts.CheckedShape.ToValueArray())); return Slice(input, starts, ends, axes, steps); } diff --git a/src/Nncase.Importer/TFLite/TFLiteImporter.cs b/src/Nncase.Importer/TFLite/TFLiteImporter.cs index bc190e5553..8a6343933e 100644 --- a/src/Nncase.Importer/TFLite/TFLiteImporter.cs +++ b/src/Nncase.Importer/TFLite/TFLiteImporter.cs @@ -156,7 +156,7 @@ private static Dimension[] GetShapeArray(tflite.Tensor tensor) } return Enumerable.Range(0, tensor.ShapeLength).Select(i => - tensor.ShapeSignature(i) < 0 ? Dimension.Unknown() : tensor.Shape(i)).ToArray(); + tensor.ShapeSignature(i) < 0 ? Dimension.Unknown : tensor.Shape(i)).ToArray(); } private void Visit(in tflite.Operator op) diff --git a/src/Nncase.Passes/Rules/Neutral/FoldConstant.cs b/src/Nncase.Passes/Rules/Neutral/FoldConstant.cs index 67b2a9dc18..1d6f507d3e 100644 --- a/src/Nncase.Passes/Rules/Neutral/FoldConstant.cs +++ b/src/Nncase.Passes/Rules/Neutral/FoldConstant.cs @@ -48,17 +48,10 @@ private Const GetReplace(Call call, IReadOnlyList constArgs) public partial class FoldShapeOf : RewriteRule { /// - public override CallPattern Pattern { get; } = IsShapeOf(IsWildcard("wc") with { TypePattern = HasRank() }); + public override CallPattern Pattern { get; } = IsShapeOf(IsWildcard("wc") with { TypePattern = HasFixedShape() }); private Const GetReplace(Expr wc) { - if (wc.CheckedShape.IsFixed) - { - return Const.FromTensor(wc.CheckedShape.ToValueArray().Select(x => (long)x).ToArray()); - } - else - { - return new ShapeConst(wc.CheckedShape); - } + return Const.FromTensor(wc.CheckedShape.ToValueArray().Select(x => (long)x).ToArray()); } } diff --git a/src/Nncase.Passes/Rules/Neutral/FoldGetItemReshape.cs b/src/Nncase.Passes/Rules/Neutral/FoldGetItemReshape.cs new file mode 100644 index 0000000000..dc1b330a18 --- /dev/null +++ b/src/Nncase.Passes/Rules/Neutral/FoldGetItemReshape.cs @@ -0,0 +1,26 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using Nncase.IR; +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; +using static Nncase.PatternMatch.F.Tensors; +using static Nncase.PatternMatch.Utility; +using GetItem = Nncase.IR.Tensors.GetItem; + +namespace Nncase.Passes.Rules.Neutral; + +[RuleGenerator] +public partial class FoldGetItemReshape : RewriteRule +{ + public override Pattern Pattern => IsGetItem(null, "getItem", ReshapePattern, new long[] { 0 }); + + public Pattern ReshapePattern => IsReshape(IsWildcard("input"), new long[] { 1 }); + + private Expr? GetReplace(Expr input, int index) + { + return input; + } +} diff --git a/src/Nncase.Passes/Rules/Neutral/UnSqueezeToReshape.cs b/src/Nncase.Passes/Rules/Neutral/UnSqueezeToReshape.cs index 8d0adde45e..2a8d22a604 100644 --- a/src/Nncase.Passes/Rules/Neutral/UnSqueezeToReshape.cs +++ b/src/Nncase.Passes/Rules/Neutral/UnSqueezeToReshape.cs @@ -49,6 +49,6 @@ public sealed partial class UnSqueezeToReshape : IRewriteRule } } - return Reshape(input, new Shape(newShape)); + return Reshape(input, new Shape(newShape).ToValueArray()); } } diff --git a/src/Nncase.Passes/SimplifyProvider.cs b/src/Nncase.Passes/SimplifyProvider.cs index d49d97d1fb..fc8a7dea8b 100644 --- a/src/Nncase.Passes/SimplifyProvider.cs +++ b/src/Nncase.Passes/SimplifyProvider.cs @@ -60,6 +60,7 @@ public SimplifyProvider() new GatherToGetItem(), new FoldGetItemShapeOf(), new FoldGetItemConcat(), + new FoldGetItemReshape(), new FoldSplitShapeOf(), ]; } diff --git a/src/Nncase.Passes/Transforms/InferRangePass.cs b/src/Nncase.Passes/Transforms/InferRangePass.cs index 24e4f2a004..6d8513b641 100644 --- a/src/Nncase.Passes/Transforms/InferRangePass.cs +++ b/src/Nncase.Passes/Transforms/InferRangePass.cs @@ -40,9 +40,12 @@ protected override Task RunCoreAsync(BaseFunction pre, RunPassCont internal sealed class InferRangeVisitor : ExprVisitor, Unit> { public InferRangeVisitor() + : base(visitAttributes: true) { } + public override Unit DefaultVisitTypeLeaf(IRType type, Unit context) => default; + protected override ValueRange DispatchVisit(Expr expr) { if (expr.Metadata.Range is null) @@ -64,7 +67,7 @@ protected override ValueRange VisitLeafCall(Call expr) var range = expr.Target switch { Op op => InferenceOp(op, expr), - BaseFunction func => ValueRange.Full, + BaseFunction => ValueRange.Full, _ => ValueRange.Full, }; return range; @@ -76,6 +79,17 @@ protected override ValueRange VisitLeafTensorConst(TensorConst expr) return new ValueRange(value.Min(), value.Max()); } + protected override ValueRange VisitLeafShape(Shape expr) + { + if (!expr.Any()) + { + return ValueRange.Full; + } + + var ranges = expr.Select(x => x.IsFixed ? new ValueRange(x.FixedValue, x.FixedValue) : Visit(x.Value)).ToArray(); + return new ValueRange(ranges.Min(x => x.Min), ranges.Max(x => x.Max)); + } + protected override ValueRange VisitLeafTuple(IR.Tuple expr) { var ranges = expr.Fields.AsValueEnumerable().Select(Visit).ToArray(); @@ -86,12 +100,15 @@ private ValueRange InferenceOp(Op op, Call expr) { return op switch { - Reshape => expr[Reshape.Input].Metadata.Range!.Value, - Slice => expr[Slice.Input].Metadata.Range!.Value, - Gather => expr[Gather.Input].Metadata.Range!.Value, - GetItem => expr[GetItem.Input].Metadata.Range!.Value, - Concat => expr[Concat.Input].Metadata.Range!.Value, + Reshape => Visit(expr[Reshape.Input]), + Slice => Visit(expr[Slice.Input]), + Gather => Visit(expr[Gather.Input]), + GetItem => Visit(expr[GetItem.Input]), + Concat => Visit(expr[Concat.Input]), Binary binary => InferenceBinary(expr, binary.BinaryOp), + Squeeze => Visit(expr[Squeeze.Input]), + Unsqueeze => Visit(expr[Unsqueeze.Input]), + Stack => Visit(expr[Stack.Inputs]), _ => ValueRange.Full, }; } @@ -105,10 +122,40 @@ private ValueRange InferenceBinary(Call expr, BinaryOp op) { BinaryOp.Add => new(lhs.Min + rhs.Min, lhs.Max + rhs.Max), BinaryOp.Sub => new(lhs.Min - rhs.Max, lhs.Max - rhs.Min), - BinaryOp.Mul => new(lhs.Min * rhs.Min, lhs.Max * rhs.Max), + BinaryOp.Mul => VisitMul(lhs, rhs), + BinaryOp.Div => VisitDiv(lhs, rhs), BinaryOp.Max => new(Math.Max(lhs.Min, rhs.Min), Math.Max(lhs.Max, rhs.Max)), BinaryOp.Min => new(Math.Min(lhs.Min, rhs.Min), Math.Min(lhs.Max, rhs.Max)), _ => ValueRange.Full, }; } + + private ValueRange VisitDiv(ValueRange lhs, ValueRange rhs) + { + if (rhs.Min <= 0 && rhs.Max >= 0) + { + return ValueRange.Full; + } + + var values = new[] + { + lhs.Min / rhs.Min, + lhs.Min / rhs.Max, + lhs.Max / rhs.Min, + lhs.Max / rhs.Max, + }; + return new ValueRange(values.Min(), values.Max()); + } + + private ValueRange VisitMul(ValueRange lhs, ValueRange rhs) + { + var values = new[] + { + lhs.Min * rhs.Min, + lhs.Min * rhs.Max, + lhs.Max * rhs.Min, + lhs.Max * rhs.Max, + }; + return new ValueRange(values.Min(), values.Max()); + } } diff --git a/src/Nncase.Tests/Core/IR/UnitTestDimension.cs b/src/Nncase.Tests/Core/IR/UnitTestDimension.cs index 225eedd4a3..ae8bb9146b 100644 --- a/src/Nncase.Tests/Core/IR/UnitTestDimension.cs +++ b/src/Nncase.Tests/Core/IR/UnitTestDimension.cs @@ -34,7 +34,7 @@ public void TestKind() Assert.False(d1.IsUnknown); Assert.True(d1.IsFixed); - var d2 = Dimension.Unknown(); + var d2 = Dimension.Unknown; Assert.Equal(DimensionKind.Unknown, d2.Kind); Assert.True(d2.IsUnknown); Assert.False(d2.IsFixed); @@ -67,7 +67,7 @@ public void TestOperatorAdd() var v2 = 1; Dimension d1 = v1; Dimension d2 = v2; - var d3 = Dimension.Unknown(); + var d3 = Dimension.Unknown; var d4 = d1 + d2; Assert.Equal(v1 + v2, d4.Value); @@ -88,7 +88,7 @@ public void TestOperatorSubtract() var v2 = 1; Dimension d1 = v1; Dimension d2 = v2; - var d3 = Dimension.Unknown(); + var d3 = Dimension.Unknown; var d4 = d1 - d2; Assert.Equal(v1 - v2, d4.Value); @@ -105,7 +105,7 @@ public void TestOperatorMul() var v2 = 1; Dimension d1 = v1; Dimension d2 = v2; - var d3 = Dimension.Unknown(); + var d3 = Dimension.Unknown; var d4 = d1 * d2; Assert.Equal(v1 * v2, d4.Value); @@ -122,7 +122,7 @@ public void TestOperatorDiv() var v2 = 1; Dimension d1 = v1; Dimension d2 = v2; - var d3 = Dimension.Unknown(); + var d3 = Dimension.Unknown; var d4 = d1 / d2; Assert.Equal(v1 / v2, d4.Value); diff --git a/src/Nncase.Tests/Core/UnitTestTypeInfer.cs b/src/Nncase.Tests/Core/UnitTestTypeInfer.cs index d04efd5857..40a6e5a57d 100644 --- a/src/Nncase.Tests/Core/UnitTestTypeInfer.cs +++ b/src/Nncase.Tests/Core/UnitTestTypeInfer.cs @@ -319,7 +319,7 @@ public void TestBroadcastInfer() [Fact] public void TestBroadcastInfer2() { - var dimUnk1 = Dimension.Unknown(); + var dimUnk1 = Dimension.Unknown; var a = new TensorType(DataTypes.Float32, new Dimension[] { 1, dimUnk1, 8192 }); var b = new TensorType(DataTypes.Float32, new Dimension[] { 1 }); var result = TypeInference.BroadcastType(a, b); @@ -330,17 +330,17 @@ public void TestBroadcastInfer2() public void TestReshapeInfer() { var dimVar = new Var("seq_len", new TensorType(DataTypes.Int64, Shape.Scalar)); - var dimC = new Dimension(dimVar); - var a = new Var(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimVar, 128 })); - var constShape = new ShapeConst(new[] { 1, dimC, 2, 64 }); + var dimC = (Dimension)dimVar; + var a = new Var(new TensorType(DataTypes.Float32, new Shape(1, dimVar, 128))); + var constShape = new Shape(1, dimC, 2, 64); var reshape = Reshape(a, constShape); var result = reshape.CheckedType; - Assert.Equal(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimVar, 2, 64 }), result); + Assert.Equal(new TensorType(DataTypes.Float32, new Shape(1, dimVar, 2, 64)), result); - var b = new Var(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimVar, 14, 64 })); - var reshapeb = Reshape(b, new ShapeConst(new[] { 1, dimC, -1 })); + var b = new Var(new TensorType(DataTypes.Float32, new Shape(1, dimVar, 14, 64))); + var reshapeb = Reshape(b, new Shape(1, dimC, -1)); var resultb = reshapeb.CheckedType; - Assert.Equal(new TensorType(DataTypes.Float32, new Dimension[] { 1, dimVar, 896 }), resultb); + Assert.Equal(new TensorType(DataTypes.Float32, new Shape(1, dimVar, 896)), resultb); } [Fact] @@ -348,8 +348,8 @@ public void TestConcatInfer() { var seq_len = new Var("seq_len", new TensorType(DataTypes.Int64, Shape.Scalar)); var hist_len = new Var("his_len", new TensorType(DataTypes.Int64, Shape.Scalar)); - var lhs = new Var(new TensorType(DataTypes.Float32, new Dimension[] { 1, seq_len, 2, 64 })); - var rhs = new Var(new TensorType(DataTypes.Float32, new Dimension[] { 1, hist_len, 2, 64 })); + var lhs = new Var(new TensorType(DataTypes.Float32, new Shape(1, seq_len, 2, 64))); + var rhs = new Var(new TensorType(DataTypes.Float32, new Shape(1, hist_len, 2, 64))); var reshape = Concat(new IR.Tuple(new[] { lhs, rhs }), 1); var result = reshape.CheckedType; Assert.IsType(result); diff --git a/src/Nncase.Tests/Evaluator/UnitTestShapeEvaluator.cs b/src/Nncase.Tests/Evaluator/UnitTestShapeEvaluator.cs index 84c89b74f8..d9e5c2ac15 100644 --- a/src/Nncase.Tests/Evaluator/UnitTestShapeEvaluator.cs +++ b/src/Nncase.Tests/Evaluator/UnitTestShapeEvaluator.cs @@ -48,7 +48,7 @@ public void TestConstant2() [Fact] public void TestWithVar() { - var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, 3, Dimension.Unknown(), 6 })); + var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, 3, Dimension.Unknown, 6 })); var dimVar = new Var(new TensorType(DataTypes.Int32, Shape.Scalar)); var newShape = new Expr[] { 1, 3, dimVar, 6 }; var varMap = new Dictionary { { input, newShape } }; @@ -182,7 +182,7 @@ public void UnitTestReshape() public void UnitTestGetItem() { var dimVar = new Var(new TensorType(DataTypes.Int32, Shape.Scalar)); - var input = new Var(new TensorType(DataTypes.Int32, new[] { Dimension.Unknown() })); + var input = new Var(new TensorType(DataTypes.Int32, new[] { Dimension.Unknown })); var expr = input[1]; var dict = new Dictionary { { input, new[] { dimVar } } }; var shape = expr.EvaluateShapeExpr(dict); @@ -200,7 +200,7 @@ public void UnitTestGetItem() public void UnitTestGetItemSingle() { var dimVar = new Var(new TensorType(DataTypes.Int32, Shape.Scalar)); - var input = new Var(new TensorType(DataTypes.Int32, new[] { Dimension.Unknown() })); + var input = new Var(new TensorType(DataTypes.Int32, new[] { Dimension.Unknown })); var expr = input[0]; var dict = new Dictionary { { input, new[] { dimVar } } }; var shape = expr.EvaluateShapeExpr(dict); @@ -260,7 +260,7 @@ public void UnitTestPad() public void TestSpaceTobatch() { var dimVar = new Var(new TensorType(DataTypes.Int32, Shape.Scalar)); - var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 192 })); + var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 192 })); var paddings = Tensor.From(new[] { 0, 1 }, [1, 2]); var expr = NCHWToNHWC(SpaceToBatch(NHWCToNCHW(input), new[] { 3 }, paddings)); var dict = new Dictionary { { input, new Expr[] { 1, dimVar, 192 } } }; @@ -280,7 +280,7 @@ public void TestSpaceTobatch() public void TestBatchToSpace() { var dimVar = new Var(new TensorType(DataTypes.Int32, Shape.Scalar)); - var input = new Var(new TensorType(DataTypes.Float32, new[] { Dimension.Unknown(), 69, 192 })); + var input = new Var(new TensorType(DataTypes.Float32, new[] { Dimension.Unknown, 69, 192 })); var paddings = Tensor.From(new[] { 0, 1 }, [1, 2]); var expr = BatchToSpace(input, new[] { 3 }, paddings); var dict = new Dictionary { { input, new Expr[] { dimVar, 69, 192 } } }; @@ -357,7 +357,7 @@ private void TestOpShapeEval(Func exprCtor, Var input, Expr[] newSha private void TestOpShapeEval(Func exprCtor) { - var (input, newShape) = MakeInput(new[] { 1, 3, Dimension.Unknown(), 24 }); + var (input, newShape) = MakeInput(new[] { 1, 3, Dimension.Unknown, 24 }); TestOpShapeEval(exprCtor, input, newShape); } } diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestFlattenToReshape.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestFlattenToReshape.cs index e2c17ae390..e0a88f473e 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestFlattenToReshape.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestFlattenToReshape.cs @@ -37,7 +37,7 @@ public class UnitTestFlattenToReshape : TransformTestBase public static IEnumerable TestFlattenToReshapeNegativeData => new[] { - new object[] { new[] { 2, 4, IR.Dimension.Unknown() }, 1 }, + new object[] { new[] { 2, 4, IR.Dimension.Unknown }, 1 }, }; [Theory] diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestReshapeBatchMatmul.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestReshapeBatchMatmul.cs index 89ead9bf60..d2fdea0d34 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestReshapeBatchMatmul.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestReshapeBatchMatmul.cs @@ -34,7 +34,7 @@ public class UnitTestReshapeBatchMatmul : TransformTestBase public static IEnumerable TestReshapeBatchMatmulNegativeData => new[] { - new object[] { new[] { 2, 1, 4 }, new[] { 4, Dimension.Unknown() } }, + new object[] { new[] { 2, 1, 4 }, new[] { 4, Dimension.Unknown } }, }; [Theory] diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs index d02aaabbaf..563693c66f 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs @@ -34,7 +34,7 @@ public class UnitTestSpaceToBatchToPad : TransformTestBase new[] { new object[] { new[] { 1, 128, 128, new IR.Dimension(1) }, new[] { 2, 2 }, new[,] { { 0, 0 }, { 0, 0 } } }, - new object[] { new[] { 1, 128, 128, IR.Dimension.Unknown() }, new[] { 1, 1 }, new[,] { { 1, 1 }, { 1, 1 } } }, + new object[] { new[] { 1, 128, 128, IR.Dimension.Unknown }, new[] { 1, 1 }, new[,] { { 1, 1 }, { 1, 1 } } }, }; [Theory] diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeToReshape.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeToReshape.cs index f56c2a4d39..ece1de55f1 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeToReshape.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeToReshape.cs @@ -33,7 +33,7 @@ public class UnitTestSqueezeToReshape : TransformTestBase public static IEnumerable TestSqueezeToReshapeNegativeData => new[] { - new object[] { new[] { 2, 4, IR.Dimension.Unknown() }, Array.Empty() }, + new object[] { new[] { 2, 4, IR.Dimension.Unknown }, Array.Empty() }, }; [Theory] diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestUnSqueezeToReshape.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestUnSqueezeToReshape.cs index 0f6b9dd023..f6b4f9c9bb 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestUnSqueezeToReshape.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestUnSqueezeToReshape.cs @@ -30,7 +30,7 @@ public class UnitTestUnSqueezeToReshape : TransformTestBase }; public static IEnumerable TestUnSqueezeToReshapeNegativeData => - new[] { new object[] { new[] { 2, 4, IR.Dimension.Unknown() }, new[] { -1 } }, }; + new[] { new object[] { new[] { 2, 4, IR.Dimension.Unknown }, new[] { -1 } }, }; [Theory] [MemberData(nameof(TestUnSqueezeToReshapePositiveData))] diff --git a/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs b/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs index 19f7fccefa..cec7da22af 100644 --- a/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs +++ b/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs @@ -64,7 +64,7 @@ public void TestBucketPad() [Fact] public async Task TestSingleVarFusionBucket() { - var mainVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 })); + var mainVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); var dimVar = Scalar("dimVar"); CompileOptions.ShapeBucketOptions.Enable = true; CompileOptions.ShapeBucketOptions.SegmentsCount = 2; @@ -73,7 +73,7 @@ public async Task TestSingleVarFusionBucket() CompileOptions.ShapeBucketOptions.VarMap = new Dictionary { { mainVar, new Expr[] { 1, dimVar, 24, 24 } } }; var input = Testing.Rand(1, 3, 24, 24); - var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 })); + var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); var f = new BucketFusion("MatMul_0", "stackvm", IR.F.Math.MatMul(fusionVar, fusionVar), new[] { fusionVar }, new[] { dimVar }); var main = new Function("main", new Call(f, mainVar), mainVar); var shape = new Dictionary(); @@ -87,7 +87,7 @@ public async Task TestSingleVarFusionBucket() [Fact] public async Task TestRebuild() { - var mainVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 })); + var mainVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); var dimVar = Scalar("dimVar"); CompileOptions.ShapeBucketOptions.Enable = true; CompileOptions.ShapeBucketOptions.SegmentsCount = 2; @@ -96,7 +96,7 @@ public async Task TestRebuild() CompileOptions.ShapeBucketOptions.VarMap = new Dictionary { { mainVar, new Expr[] { 1, dimVar, 24, 24 } } }; var input = Testing.Rand(1, 3, 24, 24); - var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 })); + var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); var shapeVar = new Var(new TensorType(DataTypes.Int64, new[] { 4 })); var body = IR.F.Math.MatMul(Reshape(fusionVar, shapeVar), fusionVar); var f = new BucketFusion("MatMul_0", "stackvm", body, new[] { fusionVar, shapeVar }, new[] { dimVar }); @@ -113,7 +113,7 @@ public async Task TestRebuild() [Fact] public async Task TestTupleOutput() { - var mainVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 })); + var mainVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); var dimVar = Scalar("dimVar"); CompileOptions.ShapeBucketOptions.Enable = true; CompileOptions.ShapeBucketOptions.SegmentsCount = 2; @@ -122,7 +122,7 @@ public async Task TestTupleOutput() CompileOptions.ShapeBucketOptions.VarMap = new Dictionary { { mainVar, new Expr[] { 1, dimVar, 24, 24 } } }; var input = Testing.Rand(1, 3, 24, 24); - var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 })); + var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); var mm = IR.F.Math.MatMul(fusionVar, fusionVar); var body = new IR.Tuple(mm, mm); var f = new BucketFusion("MatMul_0", "stackvm", body, new[] { fusionVar }, new[] { dimVar }); @@ -138,8 +138,8 @@ public async Task TestTupleOutput() [Fact] public async Task TestDoubleVarFusionBucket() { - var mainVarLhs = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 })); - var mainVarRhs = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 })); + var mainVarLhs = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); + var mainVarRhs = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); var dimVar1 = Scalar("dimVar1"); var dimVar2 = Scalar("dimVar2"); CompileOptions.ShapeBucketOptions.Enable = true; @@ -158,7 +158,7 @@ public async Task TestDoubleVarFusionBucket() var inputLhs = Testing.Rand(1, 3, 24, 24); var inputRhs = Testing.Rand(1, 3, 24, 24); - var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 })); + var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); var f = new BucketFusion("MatMul_0", "stackvm", IR.F.Math.MatMul(fusionVar, fusionVar), new[] { fusionVar }, new[] { dimVar1 }); var main = new Function("main", new Call(f, mainVarLhs, mainVarRhs), mainVarLhs, mainVarRhs); var shape = new Dictionary(); @@ -176,8 +176,8 @@ public async Task TestDoubleVarFusionBucket() [Fact] public async Task TestDoubleVarWithMultiDimEffect() { - var mainVarLhs = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 })); - var mainVarRhs = new Var(new TensorType(DataTypes.Float32, new[] { Dimension.Unknown(), 1, 24, 24 })); + var mainVarLhs = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); + var mainVarRhs = new Var(new TensorType(DataTypes.Float32, new[] { Dimension.Unknown, 1, 24, 24 })); var dimVar1 = Scalar("dimVar1"); var dimVar2 = Scalar("dimVar2"); CompileOptions.ShapeBucketOptions.Enable = true; @@ -196,7 +196,7 @@ public async Task TestDoubleVarWithMultiDimEffect() var inputLhs = Testing.Rand(1, 3, 24, 24); var inputRhs = Testing.Rand(3, 1, 24, 24); - var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown(), 24, 24 })); + var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); var f = new BucketFusion("MatMul_0", "stackvm", IR.F.Math.MatMul(fusionVar, fusionVar), new[] { fusionVar }, new[] { dimVar1 }); var main = new Function("main", new Call(f, mainVarLhs, mainVarRhs), mainVarLhs, mainVarRhs); var shape = new Dictionary(); diff --git a/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldGetItemShapeOf.cs b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldGetItemShapeOf.cs index a043596f1a..bab7175f5a 100644 --- a/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldGetItemShapeOf.cs +++ b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldGetItemShapeOf.cs @@ -21,7 +21,7 @@ public class UnitTestFoldGetItemShapeOf : TransformTestBase [Fact] public void TestFoldGetItemShapeOf() { - var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, 3, Dimension.Unknown(), 24 })); + var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, 3, Dimension.Unknown, 24 })); var data = Testing.Rand(1, 3, 24, 24); var dict = new Dictionary { { input, Value.FromTensor(data) } }; TestMatched(ShapeOf(input)[1], dict); @@ -30,7 +30,7 @@ public void TestFoldGetItemShapeOf() [Fact] public void TestFoldGetItemShapeOfWithCast() { - var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, 3, Dimension.Unknown(), 24 })); + var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, 3, Dimension.Unknown, 24 })); var data = Testing.Rand(1, 3, 24, 24); var dict = new Dictionary { { input, Value.FromTensor(data) } }; TestMatched(Cast(ShapeOf(input), DataTypes.Int32)[1], dict); @@ -39,7 +39,7 @@ public void TestFoldGetItemShapeOfWithCast() [Fact] public void TestFoldGetItemShapeOfWithDynamic() { - var input = new Var(new TensorType(DataTypes.Int32, new[] { 1, 3, Dimension.Unknown(), 24 })); + var input = new Var(new TensorType(DataTypes.Int32, new[] { 1, 3, Dimension.Unknown, 24 })); TestNotMatch(ShapeOf(input)[2]); } }