Skip to content

Commit

Permalink
Fix type infer
Browse files Browse the repository at this point in the history
  • Loading branch information
sunnycase committed Jan 15, 2025
1 parent 07932ae commit 635346b
Show file tree
Hide file tree
Showing 23 changed files with 345 additions and 181 deletions.
2 changes: 2 additions & 0 deletions modules/Nncase.Modules.CPU/Targets/CPUTarget.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
using Nncase.CodeGen.StackVM;
using Nncase.IR;
using Nncase.Passes;
using Nncase.Passes.Rules.ShapeBucket;
using Nncase.Passes.Transforms;
using Nncase.Quantization;

Expand Down Expand Up @@ -121,6 +122,7 @@ public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, Comp
p.Add<Passes.Rules.Neutral.FoldConstCall>();
});

passManager.Add<AddFunctionToModule>();
passManager.Add<CPUFunctionPartitionPass>();

passManager.Add<CPUFusionToModulePass>();
Expand Down
2 changes: 2 additions & 0 deletions src/Nncase.Compiler/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.UnSqueezeToReshape>();
p.Add<Passes.Rules.ShapeExpr.GatherToGetItem>();
p.Add<Passes.Rules.ShapeExpr.FoldGetItemShapeOf>();
p.Add<Passes.Rules.Neutral.FoldGetItemConcat>();
p.Add<Passes.Rules.Neutral.FoldIf>();
p.Add<Passes.Rules.Neutral.FoldNopReduce>();
p.Add<Passes.Rules.Neutral.SliceToGetItem>();
p.Add<Passes.Rules.Neutral.FoldTwoPads>();
Expand Down
11 changes: 11 additions & 0 deletions src/Nncase.Core/CompilerServices.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
using System.Threading.Tasks;
using DryIoc;
using Microsoft.Extensions.DependencyInjection;
using NetFabric.Hyperlinq;
using Nncase.CostModel;
using Nncase.Evaluator;
using Nncase.IR;
Expand Down Expand Up @@ -532,6 +533,16 @@ public static void DumpPatternIR(Expr expr, string prefix, string dumpDir) =>

public static Expr SimplifyForDimension(Expr value) => Provider.SimplifyForDimension(value);

public static Expr FastSimplifyForDimension(Expr value)
{
if (value is Call call && call.Arguments.AsValueEnumerable().All(x => x is Const))
{
return SimplifyForDimension(value);
}

return value;
}

internal static DryIoc.IContainer CreateScope()
{
var container = (DryIoc.IContainer)_serviceProvider!;
Expand Down
7 changes: 6 additions & 1 deletion src/Nncase.Core/Evaluator/ITypeInferenceContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using NetFabric.Hyperlinq;
using Nncase.IR;

namespace Nncase.Evaluator;
Expand All @@ -23,7 +24,11 @@ public interface ITypeInferenceContext
/// <returns>The argument expression.</returns>
Expr GetArgument(Op op, ParameterInfo parameter);

Expr GetDimensionArgument(Op op, ParameterInfo parameter) => CompilerServices.SimplifyForDimension(GetArgument(op, parameter));
Expr GetDimensionArgument(Op op, ParameterInfo parameter)
{
var arg = GetArgument(op, parameter);
return CompilerServices.FastSimplifyForDimension(arg);
}

/// <summary>
/// Get arguments expression.
Expand Down
13 changes: 10 additions & 3 deletions src/Nncase.Core/IR/Dimension.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ public Dimension(long value)

public Dimension(Expr value)
{
value = CompilerServices.SimplifyForDimension(value);
value = CompilerServices.FastSimplifyForDimension(value);
if (value is TensorConst tc)
{
Kind = DimensionKind.Fixed;
_fixedValue = tc.Value.ToScalar<int>();
_fixedValue = tc.Value.ToScalar<long>();
}
else
{
Expand Down Expand Up @@ -129,6 +129,8 @@ public long FixedValue
public static Dimension operator +(Dimension lhs, Dimension rhs) => (lhs.IsFixed, rhs.IsFixed) switch
{
(true, true) => lhs.FixedValue + rhs.FixedValue,
(true, _) when lhs.FixedValue == 0 => rhs,
(_, true) when rhs.FixedValue == 0 => lhs,
(_, _) => new Dimension(lhs.Value + rhs.Value),
};

Expand All @@ -137,12 +139,15 @@ public long FixedValue
public static Dimension operator -(Dimension lhs, Dimension rhs) => (lhs.IsFixed, rhs.IsFixed) switch
{
(true, true) => lhs.FixedValue - rhs.FixedValue,
(_, true) when rhs.FixedValue == 0 => lhs,
(_, _) => new Dimension(lhs.Value - rhs.Value),
};

public static Dimension operator *(Dimension lhs, Dimension rhs) => (lhs.IsFixed, rhs.IsFixed) switch
{
(true, true) => lhs.FixedValue * rhs.FixedValue,
(true, _) when lhs.FixedValue == 1 => rhs,
(_, true) when rhs.FixedValue == 1 => lhs,
(_, _) => new Dimension(lhs.Value * rhs.Value),
};

Expand All @@ -159,7 +164,7 @@ public static Dimension Abs(Dimension value)
return System.Math.Abs(value.FixedValue);
}

return IR.F.Math.Abs(value.Value);
return value.Value.Metadata.Range.Min >= 0 ? value.Value : IR.F.Math.Abs(value.Value);
}

public static Dimension Clamp(Dimension value, Dimension min, Dimension max)
Expand Down Expand Up @@ -230,5 +235,7 @@ public bool IsAssignableFrom(Dimension dimension)

return dimension.Kind == DimensionKind.Fixed && Value == dimension.Value;
}

public Expr ToExpr() => IsFixed ? FixedValue : Value;
}
}
4 changes: 3 additions & 1 deletion src/Nncase.Core/IR/Expr.Conversion.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,9 @@ public abstract partial class Expr
/// Create <see cref="Expr"/> from a <see cref="Shape"/>.
/// </summary>
/// <param name="shape">Shape.</param>
public static implicit operator Expr(Shape shape) => Const.FromShape(shape);
public static implicit operator Expr(Shape shape) =>
shape.IsFixed ? Const.FromShape(shape)
: IR.F.Tensors.Stack(new IR.Tuple(shape.Select(x => x.ToExpr()).ToArray()), 0);

/// <summary>
/// Create <see cref="Expr"/> from an array of<see cref="int"/>.
Expand Down
16 changes: 16 additions & 0 deletions src/Nncase.Core/IR/Expr.Operators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR.Tensors;
using static Nncase.IR.F.Math;

namespace Nncase.IR;
Expand All @@ -21,6 +22,21 @@ public partial class Expr
/// <param name="index"> expr. </param>
public Expr this[Expr index] => F.Tensors.GetItem(this, index);

/// <summary>
/// get the item from the expr.
/// </summary>
/// <returns> expr. </returns>
public Expr this[params long[] indices] =>
this switch
{
TensorConst tc => Tensor.FromScalar(tc.Value.ElementType, tc.Value[indices]),
TupleConst tc => tc.Value[(int)indices.Single()].AsTensor(),
IR.Tuple t => t.Fields[(int)indices.Single()],
Call { Target: Concat { Axis: 0 } } c when indices.Length == 1 => c[Concat.Input][indices[0]][0],
Call { Target: Reshape } c when c[Reshape.Shape] is TensorConst tc && tc.Value.Length == 1 && tc.Value.ToScalar<long>() == 1 => c[Reshape.Input],
_ => this[indices.Select(x => (Expr)x).ToArray()],
};

/// <summary>
/// get the item from the expr.
/// </summary>
Expand Down
2 changes: 2 additions & 0 deletions src/Nncase.Core/IR/Expr.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ public class IRMetadata
/// Gets or sets outputs names.
/// </summary>
public IReadOnlyList<string>? OutputNames { get; set; }

public ValueRange<double> Range { get; set; } = ValueRange<double>.Full;
}

/// <summary>
Expand Down
96 changes: 94 additions & 2 deletions src/Nncase.Core/IR/Shape.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,70 @@ public enum ShapeKind
Fixed,
}

public record struct FixedAndDynamicDimension(long Fixed, Dimension? Dynamic)
{
public static implicit operator FixedAndDynamicDimension((long Fixed, Dimension? Dynamic) value) => new FixedAndDynamicDimension(value.Fixed, value.Dynamic);

public static FixedAndDynamicDimension operator *(FixedAndDynamicDimension a, FixedAndDynamicDimension b)
{
var dyn = (a.Dynamic, b.Dynamic) switch
{
(null, null) => null,
(null, Dimension x) => x,
(Dimension x, null) => x,
(Dimension x, Dimension y) => x * y,
};
return new FixedAndDynamicDimension(a.Fixed * b.Fixed, dyn);
}

public static FixedAndDynamicDimension operator /(FixedAndDynamicDimension a, long b)
{
if (a.Fixed % b == 0 || a.Dynamic is null)
{
return new FixedAndDynamicDimension(a.Fixed / b, a.Dynamic);
}

return new FixedAndDynamicDimension(1, a.Fixed * a.Dynamic.Value / b);
}

public static FixedAndDynamicDimension operator /(FixedAndDynamicDimension a, FixedAndDynamicDimension b)
{
if (a.Fixed % b.Fixed == 0)
{
return (a.Dynamic, b.Dynamic) switch
{
(null, null) => new FixedAndDynamicDimension(a.Fixed / b.Fixed, null),
(null, Dimension y) => new FixedAndDynamicDimension(1, a.Fixed / b.Fixed / y),
(Dimension x, null) => new FixedAndDynamicDimension(a.Fixed / b.Fixed, x),
(Dimension x, Dimension y) when x == y => new FixedAndDynamicDimension(a.Fixed / b.Fixed, null),
(Dimension x, Dimension y) => new FixedAndDynamicDimension(1, a.Fixed / b.Fixed * x / y),
};
}

return (a.Dynamic, b.Dynamic) switch
{
(null, null) => new FixedAndDynamicDimension(a.Fixed / b.Fixed, null),
(null, Dimension y) => new FixedAndDynamicDimension(1, a.Fixed / (b.Fixed * y)),
(Dimension x, null) => new FixedAndDynamicDimension(1, a.Fixed * x / b.Fixed),
(Dimension x, Dimension y) when x == y => new FixedAndDynamicDimension(a.Fixed / b.Fixed, null),
(Dimension x, Dimension y) => new FixedAndDynamicDimension(1, a.Fixed * x / (b.Fixed * y)),
};
}

public static FixedAndDynamicDimension Abs(FixedAndDynamicDimension value) =>
new(System.Math.Abs(value.Fixed), value.Dynamic is null ? null : Dimension.Abs(value.Dynamic.Value));

public Dimension ToDimension()
{
return Dynamic is null ? new Dimension(Fixed) : new Dimension(Fixed) * Dynamic.Value;
}

public Expr ToExpr()
{
return Dynamic is null ? Fixed : Fixed * Dynamic.Value;
}
}

/// <summary>
/// Tensor shape.
/// </summary>
Expand Down Expand Up @@ -265,8 +329,7 @@ public static Shape FromExpr(Expr value)
}

var rank = (int)shape[0].FixedValue;
// return new Shape(Enumerable.Range(0, rank).Select(x => (Dimension)value[x]));
return Shape.Unknown(rank);
return new Shape(Enumerable.Range(0, rank).Select(x => (Dimension)value[x]));
}

/// <summary>
Expand All @@ -277,6 +340,25 @@ public Dimension Prod()
return _dimensions.Aggregate(new Dimension(1), (x, y) => x * y);
}

public FixedAndDynamicDimension ProdFixedAndDynamic()
{
var fixedValue = 1L;
var dynamicValue = new Dimension(1);
foreach (var dim in _dimensions)
{
if (dim.IsFixed)
{
fixedValue *= dim.FixedValue;
}
else
{
dynamicValue *= dim;
}
}

return new(fixedValue, dynamicValue.IsFixed ? null : dynamicValue);
}

/// <summary>
/// return new shape after insert dim.
/// </summary>
Expand Down Expand Up @@ -389,6 +471,16 @@ public bool IsAssignableFrom(Shape shape)
return true;
}

public IR.Tuple ToTuple()
{
if (IsUnranked)
{
throw new InvalidOperationException("Cannot convert unranked shape to tuple");
}

return new IR.Tuple(_dimensions.Select(x => x.ToExpr()).ToArray());
}

private static ShapeKind KindOf(ReadOnlySpan<Dimension> dimensions)
{
return dimensions.AsValueEnumerable().Any(x => x.IsUnknown) ? ShapeKind.HasUnknownDimension : ShapeKind.Fixed;
Expand Down
14 changes: 14 additions & 0 deletions src/Nncase.Core/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
using System.Text;
using System.Threading.Tasks;
using CommunityToolkit.HighPerformance;
using Google.Protobuf.WellKnownTypes;
using Nncase.Buffers;
using Nncase.IR;
using Nncase.TIR;
Expand Down Expand Up @@ -185,6 +186,19 @@ public static Tensor<T> FromScalar<T>(T value)
return tensor;
}

/// <summary>
/// Create a scalar tensor from a scalar.
/// </summary>
/// <param name="type">Data type.</param>
/// <param name="value">Value.</param>
/// <returns>Created tensor.</returns>
public static Tensor FromScalar(DataType type, object value)
{
var tensor = Zeros(type, ReadOnlySpan<long>.Empty);
tensor[0] = value;
return tensor;
}

/// <summary>
/// Create a 1-D tensor from a scalar.
/// </summary>
Expand Down
9 changes: 2 additions & 7 deletions src/Nncase.Core/TensorUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,6 @@ public static T GetProductGeneric<T>(ReadOnlySpan<T> dimensions, int startIndex
T product = T.One;
for (int i = startIndex; i < dimensions.Length; i++)
{
if (dimensions[i] < T.Zero)
{
throw new ArgumentOutOfRangeException($"{nameof(dimensions)}[{i}]");
}

// we use a long which should be much larger than is ever used here,
// but still force checked
checked
Expand Down Expand Up @@ -75,10 +70,10 @@ public static Expr GetProduct(ReadOnlySpan<Expr> dimensions, int startIndex = 0)
{
if (dimensions.Length == 0)
{
return 1;
return 1L;
}

Expr product = 1;
Expr product = 1L;
for (int i = startIndex; i < dimensions.Length; i++)
{
var dimension = dimensions[i];
Expand Down
25 changes: 19 additions & 6 deletions src/Nncase.Core/Utilities/ShapeExprUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,24 +28,37 @@ public static Expr Positive(Expr axis, Expr inShape)

public static Expr If(Expr condition, Func<Var, Expr> thenExpr, Func<Var, Expr> elseExpr, Expr arg)
{
var var1 = new Var();
var var2 = new Var();
var var1 = new Var(arg.CheckedType);
var var2 = var1.With();
var thenFunc = new Function(thenExpr(var1), var1);
var elseFunc = new Function(elseExpr(var2), var2);
return new If(condition, thenFunc, elseFunc, arg);
}

public static Expr If(Expr condition, Func<Var, Var, Expr> thenExpr, Func<Var, Var, Expr> elseExpr, Expr arg1, Expr arg2)
{
var var11 = new Var();
var var21 = new Var();
var var12 = new Var();
var var22 = new Var();
var var11 = new Var(arg1.CheckedType);
var var21 = var11.With();
var var12 = new Var(arg2.CheckedType);
var var22 = var12.With();
var thenFunc = new Function(thenExpr(var11, var12), var11, var12);
var elseFunc = new Function(elseExpr(var21, var22), var21, var22);
return new If(condition, thenFunc, elseFunc, arg1, arg2);
}

public static Expr If(Expr condition, Func<Var, Var, Var, Expr> thenExpr, Func<Var, Var, Var, Expr> elseExpr, Expr arg1, Expr arg2, Expr arg3)
{
var var11 = new Var(arg1.CheckedType);
var var21 = var11.With();
var var12 = new Var(arg2.CheckedType);
var var22 = var12.With();
var var13 = new Var(arg3.CheckedType);
var var23 = var13.With();
var thenFunc = new Function(thenExpr(var11, var12, var13), var11, var12, var13);
var elseFunc = new Function(elseExpr(var21, var22, var23), var21, var22, var23);
return new If(condition, thenFunc, elseFunc, arg1, arg2, arg3);
}

public static Expr Slice(Expr shape, int begin, int end)
{
return IR.F.Tensors.Slice(CheckShape(shape), new[] { begin }, new[] { end }, 1);
Expand Down
Loading

0 comments on commit 635346b

Please sign in to comment.