diff --git a/.github/workflows/compiler-build.yml b/.github/workflows/compiler-build.yml
index 5348c3667..5ca2e7317 100644
--- a/.github/workflows/compiler-build.yml
+++ b/.github/workflows/compiler-build.yml
@@ -121,6 +121,7 @@ jobs:
${{github.workspace}}/install/lib
${{github.workspace}}/install/lib
+ 1
@@ -248,6 +249,7 @@ jobs:
shell: bash
env:
NNCASE_COMPILER: ${{github.workspace}}/install/Nncase.Compiler.dll
+ NNCASE_TILING_MAX_SOLUTIONS: 1
run: |
dotnet tool install --global dotnet-coverage
dotnet-coverage collect -s tools/dotnet_coverage.settings.xml -f cobertura -o coverage/onnx_basic.xml pytest tests/importer/onnx_/basic/ --doctest-modules --junitxml=test_results/onnx_basic.xml
diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceCompiler.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceCompiler.cs
index 0f2fd57ca..151339bb9 100644
--- a/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceCompiler.cs
+++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceCompiler.cs
@@ -171,11 +171,7 @@ private string ArgumentsSpecific(string sourcePath, string outPath)
var archConfig = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ?
"-DCMAKE_C_COMPILER=clang-cl -DCMAKE_CXX_COMPILER=clang-cl" : string.Empty;
-#if DEBUG
- var config = "Debug";
-#else
var config = "Release";
-#endif
var script = $"""
cd {sourcePath} &&
cmake -E remove_directory build &&
diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/DeviceCSourceConvertVisitor.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/DeviceCSourceConvertVisitor.cs
index 8e6976380..425399792 100644
--- a/modules/Nncase.Modules.CPU/CodeGen/CPU/DeviceCSourceConvertVisitor.cs
+++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/DeviceCSourceConvertVisitor.cs
@@ -324,7 +324,7 @@ protected override CSymbol VisitCall(Call expr)
IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Matmul.cshtml", new TypedKernelTemplateModel(matmul)
{
Arguments = arguments.Select(x => new KernelArgument { Symbol = x }).ToArray(),
- Indent = string.Join(string.Empty, Enumerable.Repeat(' ', IndentScope.Writer.Indent)),
+ Indent = new string(' ', IndentScope.Writer.Indent),
}).Result);
break;
@@ -332,9 +332,26 @@ protected override CSymbol VisitCall(Call expr)
IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Pack.cshtml", new TypedKernelTemplateModel(pack)
{
Arguments = arguments.Select(x => new KernelArgument { Symbol = x }).ToArray(),
- Indent = string.Join(string.Empty, Enumerable.Repeat(' ', IndentScope.Writer.Indent)),
+ Indent = new string(' ', IndentScope.Writer.Indent),
}).Result);
break;
+ case TIR.CPU.Transpose transpose:
+ IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Transpose.cshtml", new TypedKernelTemplateModel(transpose)
+ {
+ Arguments = arguments.Select(x => new KernelArgument { Symbol = x }).ToArray(),
+ Indent = new string(' ', IndentScope.Writer.Indent),
+ }).Result);
+ break;
+ case TIR.CPU.Unpack unpack:
+ IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Unpack.cshtml", new TypedKernelTemplateModel(unpack)
+ {
+ Arguments = arguments.Select(x => new KernelArgument { Symbol = x }).ToArray(),
+ Indent = new string(' ', IndentScope.Writer.Indent),
+ }).Result);
+ break;
+ case TIR.CPU.Reduce reduce:
+ IndentScope.Writer.IndWrite($"reduce_{reduce.ReduceOp.ToC()}, fixed_shape<{string.Join(",", reduce.PackedAxes)}>, fixed_shape<{string.Join(",", reduce.PadedNums)}>>({arguments[0].Name}, {arguments[1].Name});\n");
+ break;
default:
throw new NotSupportedException();
}
diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Transpose.cshtml b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Transpose.cshtml
index e7c1405e7..52d97e32a 100644
--- a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Transpose.cshtml
+++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/Kernels/Transpose.cshtml
@@ -1,4 +1,4 @@
@model Nncase.CodeGen.CPU.TypedKernelTemplateModel
@{
}
-transpose>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name));
+@(Model.Indent)transpose>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name));
diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Reduce.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Reduce.cs
index 87c65d7e1..2f030b96e 100644
--- a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Reduce.cs
+++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Reduce.cs
@@ -1,13 +1,15 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+using Google.OrTools.ConstraintSolver;
using Nncase.Evaluator;
using Nncase.IR;
+using Nncase.Schedule;
using Nncase.TIR.CPU;
namespace Nncase.Evaluator.TIR.CPU;
-public sealed class ReduceEvaluator : ITypeInferencer
+public sealed class ReduceEvaluator : ITypeInferencer, IKernelInfoEvaluator
{
public IRType Visit(ITypeInferenceContext context, Reduce target)
{
@@ -15,4 +17,22 @@ public IRType Visit(ITypeInferenceContext context, Reduce target)
context.CheckArgumentType(target, Reduce.Output);
return TupleType.Void;
}
+
+ public MicroKernelInfo Visit(Reduce op, MicroKernelContext context)
+ {
+ var domain = context.AccessMaps[0].Domains;
+ var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
+ var multipliers = Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray();
+ var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
+ var opt = (ICpuTargetOptions)context.TargetOptions;
+ bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
+ bufferInfos[1] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Write);
+ return new MicroKernelInfo(primitives, multipliers, bufferInfos, GetComputeCycle);
+ }
+
+ private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context)
+ {
+ var factor = System.Math.Min(context.BufferShapes[0][^1], 32);
+ return factor * (1 + solver.MakeIsLessVar(bufferShapes[0][^1], solver.MakeIntConst(factor)));
+ }
}
diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Transpose.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Transpose.cs
index c769ce19e..456826393 100644
--- a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Transpose.cs
+++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Transpose.cs
@@ -1,13 +1,15 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+using Google.OrTools.ConstraintSolver;
using Nncase.Evaluator;
using Nncase.IR;
+using Nncase.Schedule;
using Nncase.TIR.CPU;
namespace Nncase.Evaluator.TIR.CPU;
-public sealed class TransposeEvaluator : ITypeInferencer
+public sealed class TransposeEvaluator : ITypeInferencer, IKernelInfoEvaluator
{
public IRType Visit(ITypeInferenceContext context, Transpose target)
{
@@ -15,4 +17,22 @@ public IRType Visit(ITypeInferenceContext context, Transpose target)
context.CheckArgumentType(target, Transpose.Output);
return TupleType.Void;
}
+
+ public MicroKernelInfo Visit(Transpose op, MicroKernelContext context)
+ {
+ var domain = context.AccessMaps[0].Domains;
+ var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
+ var multipliers = Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray();
+ var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
+ var opt = (ICpuTargetOptions)context.TargetOptions;
+ bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
+ bufferInfos[1] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Write);
+ return new MicroKernelInfo(primitives, multipliers, bufferInfos, GetComputeCycle);
+ }
+
+ private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context)
+ {
+ var factor = System.Math.Min(context.BufferShapes[0][^1], 32);
+ return factor * (1 + solver.MakeIsLessVar(bufferShapes[0][^1], solver.MakeIntConst(factor)));
+ }
}
diff --git a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unpack.cs b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unpack.cs
index 7e4d46837..1f2573f02 100644
--- a/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unpack.cs
+++ b/modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unpack.cs
@@ -6,16 +6,36 @@
using System.Diagnostics;
using System.Linq;
using DryIoc.ImTools;
+using Google.OrTools.ConstraintSolver;
using Nncase.CostModel;
using Nncase.IR;
+using Nncase.Schedule;
using Nncase.TIR.CPU;
using Nncase.Utilities;
using OrtKISharp;
namespace Nncase.Evaluator.TIR.CPU;
-public sealed class UnpackEvaluator : ITypeInferencer
+public sealed class UnpackEvaluator : ITypeInferencer, IKernelInfoEvaluator
{
///
public IRType Visit(ITypeInferenceContext context, Unpack target) => TupleType.Void;
+
+ public MicroKernelInfo Visit(Unpack op, MicroKernelContext context)
+ {
+ var domain = context.AccessMaps[0].Domains;
+ var primitives = Enumerable.Repeat(1, domain.Length).ToArray();
+ var multipliers = Enumerable.Repeat(new ValueRange(1, int.MaxValue), domain.Length).ToArray();
+ var bufferInfos = new MicroKernelBufferInfo[context.BufferShapes.Length];
+ var opt = (ICpuTargetOptions)context.TargetOptions;
+ bufferInfos[0] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Read);
+ bufferInfos[1] = new(opt.MemoryBandWidths[1], opt.MemoryBandWidths[1], MicroKernelBufferInfo.BufferState.Write);
+ return new MicroKernelInfo(primitives, multipliers, bufferInfos, GetComputeCycle);
+ }
+
+ private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context)
+ {
+ var factor = System.Math.Min(context.BufferShapes[0][^1], 32);
+ return factor * (1 + solver.MakeIsLessVar(bufferShapes[0][^1], solver.MakeIntConst(factor)));
+ }
}
diff --git a/modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs b/modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs
index 17502ba0d..04088b8c5 100644
--- a/modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs
+++ b/modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs
@@ -113,15 +113,15 @@ bool CheckField(Expr f)
}
// 3. reconstruction
- var constructor = new DistributedReConstructor(funcName, ModuleKind, condenseAlgo);
+ var constructor = new DistributedReconstructor(funcName, ModuleKind, condenseAlgo);
var post = constructor.Construct();
return post;
}
}
-internal sealed class DistributedReConstructor : ExprReConstructor
+internal sealed class DistributedReconstructor : ExprReconstructor
{
- public DistributedReConstructor(string funcName, string moduleKind, CondensationGraphAlgorithm algo)
+ public DistributedReconstructor(string funcName, string moduleKind, CondensationGraphAlgorithm algo)
: base(algo)
{
FuncName = funcName;
diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerReduce.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerReduce.cs
new file mode 100644
index 000000000..87da4112a
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerReduce.cs
@@ -0,0 +1,68 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR;
+using Nncase.IR.Affine;
+using Nncase.PatternMatch;
+using Nncase.Targets;
+using static Nncase.IR.TypePatternUtility;
+using static Nncase.PatternMatch.Utility;
+
+namespace Nncase.Passes.Rules.CPU.Affine;
+
+[RuleGenerator]
+public partial class LowerReduce : RewriteRule
+{
+ public LowerReduce(string moduleKind = CPUTarget.Kind)
+ {
+ ModuleKind = moduleKind;
+ }
+
+ public string ModuleKind { get; }
+
+ public override Pattern Pattern { get; } = IsCall(
+ "call",
+ IsOp("op"),
+ IsWildcard("input") with { TypePattern = HasShape(s => s.Rank > 0 && s.IsFixed, "tileable") });
+
+ private Expr? GetReplace(Expr call, IR.CPU.PackedReduce op, Expr input)
+ {
+ var inputShape = input.CheckedShape.ToValueArray();
+ var rank = inputShape.Length;
+ var domains = IR.F.Affine.Domains(rank);
+ var outrank = call.CheckedShape.Rank;
+ var results = new AffineRange[outrank];
+ {
+ var j = 0;
+ for (int i = 0; i < rank; i++)
+ {
+ if (op.Axes.Contains(i))
+ {
+ if (op.KeepDims == true)
+ {
+ results[j++] = new AffineRange(0, 1);
+ }
+ }
+ else
+ {
+ results[j++] = new AffineRange(domains[i].Offset, domains[i].Extent);
+ }
+ }
+ }
+
+ var affinemap = new AffineMap(domains, default, results);
+ var outBuffer = call.CheckedType switch
+ {
+ TensorType t => IR.F.Buffer.Uninitialized(t.DType, TIR.MemoryLocation.Data, t.Shape.ToValueArray()),
+ DistributedType dt => IR.F.Buffer.Uninitialized(dt.TensorType.DType, TIR.MemoryLocation.Data, dt.TensorType.Shape.ToValueArray(), dt.NdSBP, dt.Placement),
+ _ => throw new ArgumentOutOfRangeException(nameof(call)),
+ };
+
+ return IR.F.Affine.Grid(ModuleKind)
+ .Domain(rank, out var _)
+ .Read(input, AffineMap.Identity(rank), out var intile)
+ .Write(outBuffer, affinemap, out var outTile)
+ .Body(TIR.F.CPU.Reduce(intile, outTile, op.PackedAxes.ToArray(), op.PadedNums.ToArray(), op.Axes, op.KeepDims, op.ReduceOp))
+ .Build();
+ }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerTranspose.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerTranspose.cs
new file mode 100644
index 000000000..c7466250b
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerTranspose.cs
@@ -0,0 +1,54 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR;
+using Nncase.IR.Affine;
+using Nncase.IR.Math;
+using Nncase.PatternMatch;
+using Nncase.Targets;
+using static Nncase.IR.TypePatternUtility;
+using static Nncase.PatternMatch.Utility;
+
+namespace Nncase.Passes.Rules.CPU.Affine;
+
+[RuleGenerator]
+public partial class LowerTranspose : RewriteRule
+{
+ public LowerTranspose(string moduleKind = CPUTarget.Kind)
+ {
+ ModuleKind = moduleKind;
+ }
+
+ public string ModuleKind { get; }
+
+ public override Pattern Pattern { get; } = PatternMatch.F.Tensors.IsTranspose("trans", "call", IsWildcard("input"), IsTensorConst("perm"))
+ with
+ { TypePattern = HasShape(s => s.Rank > 0 && s.IsFixed, "tileable") };
+
+ private Expr? GetReplace(Expr call, Expr input, int[] perm)
+ {
+ var inputShape = input.CheckedShape.ToValueArray();
+ var rank = inputShape.Length;
+ var domains = IR.F.Affine.Domains(rank);
+ var results = new AffineRange[rank];
+ for (int i = 0; i < rank; i++)
+ {
+ results[perm[i]] = new AffineRange(domains[i].Offset, domains[i].Extent);
+ }
+
+ var inputAccessMap = new AffineMap(domains, default, results);
+ var outBuffer = call.CheckedType switch
+ {
+ TensorType t => IR.F.Buffer.Uninitialized(t.DType, TIR.MemoryLocation.Data, t.Shape.ToValueArray()),
+ DistributedType dt => IR.F.Buffer.Uninitialized(dt.TensorType.DType, TIR.MemoryLocation.Data, dt.TensorType.Shape.ToValueArray(), dt.NdSBP, dt.Placement),
+ _ => throw new ArgumentOutOfRangeException(nameof(call)),
+ };
+
+ return IR.F.Affine.Grid(ModuleKind)
+ .Domain(rank, out var _)
+ .Read(input, inputAccessMap, out var intile)
+ .Write(outBuffer, AffineMap.Identity(rank), out var outTile)
+ .Body(TIR.F.CPU.Transpose(intile, outTile, perm))
+ .Build();
+ }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerUnpack.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerUnpack.cs
new file mode 100644
index 000000000..f0b2983e3
--- /dev/null
+++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerUnpack.cs
@@ -0,0 +1,64 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.IR;
+using Nncase.IR.Affine;
+using Nncase.IR.Math;
+using Nncase.PatternMatch;
+using Nncase.Targets;
+using static Nncase.IR.TypePatternUtility;
+using static Nncase.PatternMatch.Utility;
+
+namespace Nncase.Passes.Rules.CPU.Affine;
+
+[RuleGenerator]
+public partial class LowerUnpack : RewriteRule
+{
+ public LowerUnpack(string moduleKind = CPUTarget.Kind)
+ {
+ ModuleKind = moduleKind;
+ }
+
+ public string ModuleKind { get; }
+
+ public override Pattern Pattern { get; } = IsCall(
+ "call",
+ IsOp("op"),
+ IsWildcard("input") with { TypePattern = HasShape(s => s.Rank > 0 && s.IsFixed, "tileable") });
+
+ private Expr? GetReplace(Expr call, IR.CPU.Unpack op, Expr input)
+ {
+ var inputShape = input.CheckedShape.ToValueArray();
+ var rank = inputShape.Length;
+ var domains = IR.F.Affine.Domains(rank);
+ var results = new AffineRange[rank];
+
+ for (int axis = 0; axis < rank; axis++)
+ {
+ // e.g. f32[128,256] -> f32<4>[32,256]
+ if (op.Axes.IndexOf(axis) is int i && i != -1)
+ {
+ results[axis] = new AffineRange(op.Lanes[i] * domains[axis].Offset, op.Lanes[i] * domains[axis].Extent);
+ }
+ else
+ {
+ results[axis] = new AffineRange(domains[axis].Offset, domains[axis].Extent);
+ }
+ }
+
+ var affinemap = new AffineMap(domains, default, results);
+ var outBuffer = call.CheckedType switch
+ {
+ TensorType t => IR.F.Buffer.Uninitialized(t.DType, TIR.MemoryLocation.Data, t.Shape.ToValueArray()),
+ DistributedType dt => IR.F.Buffer.Uninitialized(dt.TensorType.DType, TIR.MemoryLocation.Data, dt.TensorType.Shape.ToValueArray(), dt.NdSBP, dt.Placement),
+ _ => throw new ArgumentOutOfRangeException(nameof(call)),
+ };
+
+ return IR.F.Affine.Grid(ModuleKind)
+ .Domain(rank, out var _)
+ .Read(input, AffineMap.Identity(rank), out var intile)
+ .Write(outBuffer, affinemap, out var outTile)
+ .Body(TIR.F.CPU.Unpack(intile, outTile, op.Lanes, op.Axes))
+ .Build();
+ }
+}
diff --git a/modules/Nncase.Modules.CPU/Passes/Tile/KernelToTIRVisitor.cs b/modules/Nncase.Modules.CPU/Passes/Tile/KernelToTIRVisitor.cs
index 44249ca11..0ce0bdda3 100644
--- a/modules/Nncase.Modules.CPU/Passes/Tile/KernelToTIRVisitor.cs
+++ b/modules/Nncase.Modules.CPU/Passes/Tile/KernelToTIRVisitor.cs
@@ -106,7 +106,7 @@ protected override Unit VisitLeafCall(Call expr)
_mainBody.Add(TIR.F.CPU.Pack(arguments[0], ret, pack.Lanes, pack.Axes));
break;
case IR.CPU.Unpack unpack:
- _mainBody.Add(TIR.F.CPU.Unpack(arguments[0], ret, unpack.Axes));
+ _mainBody.Add(TIR.F.CPU.Unpack(arguments[0], ret, unpack.Lanes, unpack.Axes));
break;
case IR.CPU.PackedBinary packed_binary:
// _mainBody.Add(TIR.F.CPU.Binary(arguments[0], arguments[1], ret, packed_binary.BinaryOp, packed_binary.LhsPackedAxes, packed_binary.LhsPadedNums, packed_binary.RhsPackedAxes, packed_binary.RhsPadedNums));
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Functional.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Functional.cs
index d5da9ea87..46fee31d4 100644
--- a/modules/Nncase.Modules.CPU/TIR/CPU/Functional.cs
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Functional.cs
@@ -62,9 +62,9 @@ public static Expr Pack(Expr input, Expr output, IRArray lanes, IRArray new Call(new Conv2D(stride, padding, dilation, groups, padMode, distributedType), input, weights, bias, output);
- public static Expr Unpack(Expr input, Expr output, IRArray axes)
+ public static Expr Unpack(Expr input, Expr output, IRArray lanes, IRArray axes)
{
- return new Call(new Unpack(axes), input, output);
+ return new Call(new Unpack(lanes, axes), input, output);
}
public static Expr PackedSoftmax(Expr input, Expr output, int axis, IRArray packedAxes)
@@ -117,7 +117,7 @@ public static Expr Gather(Buffer input, Buffer indcies, Buffer ret, int axis)
return new Call(new Gather(axis), input, indcies, ret);
}
- public static Expr Transpose(Buffer buffer, Buffer ret, int[] perm)
+ public static Expr Transpose(Expr buffer, Expr ret, int[] perm)
{
return new Call(new Transpose(perm), buffer, ret);
}
@@ -132,7 +132,7 @@ public static Expr Im2col(Buffer input, Buffer output, IRArray kernel, IRAr
return new Call(new Im2col(kernel, stride, padding, packedAxes, padedNums), input, output);
}
- public static Expr Reduce(Buffer input, Buffer ret, int[] packedAxes, int[] padedNums, IRArray axis, bool keepDims, ReduceOp reduceOp)
+ public static Expr Reduce(Expr input, Expr ret, int[] packedAxes, int[] padedNums, IRArray axis, bool keepDims, ReduceOp reduceOp)
{
return new Call(new TIR.CPU.Reduce(packedAxes, padedNums, axis, keepDims, reduceOp), input, ret);
}
diff --git a/modules/Nncase.Modules.CPU/TIR/CPU/Unpack.cs b/modules/Nncase.Modules.CPU/TIR/CPU/Unpack.cs
index 00b0df769..f635077bb 100644
--- a/modules/Nncase.Modules.CPU/TIR/CPU/Unpack.cs
+++ b/modules/Nncase.Modules.CPU/TIR/CPU/Unpack.cs
@@ -24,6 +24,8 @@ public sealed partial class Unpack : CPUKernelOp
public static readonly ParameterInfo Output = new(typeof(Unpack), 1, "output", ParameterKind.Input);
+ public IRArray Lanes { get; }
+
public IRArray Axes { get; }
///
diff --git a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs
index 757c0c1bb..a72849284 100644
--- a/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs
+++ b/modules/Nncase.Modules.CPU/Targets/CPUTarget.cs
@@ -128,6 +128,10 @@ public void RegisterTargetDependentAfterQuantPass(IPassManager passManager, Comp
p.Add();
p.Add();
p.Add();
+ p.Add();
+ p.Add();
+
+ // p.Add();
});
// concat/reshape lower
diff --git a/ntt/include/nncase/ntt/kernels/transpose.h b/ntt/include/nncase/ntt/kernels/transpose.h
index 044f0b511..66c47d599 100644
--- a/ntt/include/nncase/ntt/kernels/transpose.h
+++ b/ntt/include/nncase/ntt/kernels/transpose.h
@@ -37,12 +37,12 @@ void transpose(const TIn &input, TOut &&output) {
constexpr auto output_shape = std::decay_t::shape();
constexpr auto output_strides = std::decay_t::strides();
constexpr auto output_rank = std::decay_t::rank();
+ constexpr auto input_rank = TIn::rank();
constexpr auto cdims_input = contiguous_dims(input_shape, input_strides);
constexpr auto cdims_output = contiguous_dims(output_shape, output_strides);
-
constexpr auto segs_cnt = segments_cnt();
- if (cdims_input == TIn::rank() && cdims_output == output_rank &&
+ if constexpr (cdims_input == input_rank && cdims_output == output_rank &&
segs_cnt <= 4) {
ntt::u_transpose(
input, output, std::make_index_sequence{});
@@ -67,4 +67,4 @@ void transpose(const TIn &input, TOut &&output) {
output(out_index) = input(index);
});
}
-} // namespace nncase::ntt
\ No newline at end of file
+} // namespace nncase::ntt
diff --git a/ntt/include/nncase/ntt/ukernels/u_transpose.h b/ntt/include/nncase/ntt/ukernels/u_transpose.h
index ff6fe0901..cf1f87519 100644
--- a/ntt/include/nncase/ntt/ukernels/u_transpose.h
+++ b/ntt/include/nncase/ntt/ukernels/u_transpose.h
@@ -145,7 +145,7 @@ class u_transpose {
template
-constexpr void u_transpose(const TIn &input, TOut &&output,
+constexpr void u_transpose(const TIn &input, TOut &output,
std::index_sequence) noexcept {
constexpr std::array dims_compressed =
diff --git a/src/Nncase.Passes/GraphPartition/ExprReConstructor.cs b/src/Nncase.Passes/GraphPartition/ExprReConstructor.cs
index a9b15ae42..8dace024b 100644
--- a/src/Nncase.Passes/GraphPartition/ExprReConstructor.cs
+++ b/src/Nncase.Passes/GraphPartition/ExprReConstructor.cs
@@ -12,11 +12,11 @@
namespace Nncase.Passes.GraphPartition;
-public class ExprReConstructor
+public class ExprReconstructor
where TVertex : IExprVertex
where TEdge : class, IExprEdge
{
- public ExprReConstructor(CondensationGraphAlgorithm algo)
+ public ExprReconstructor(CondensationGraphAlgorithm algo)
{
Algo = algo;
ClusterMemo = new(ReferenceEqualityComparer.Instance);
diff --git a/src/Nncase.Schedule/Schedule/GraphTiler.cs b/src/Nncase.Schedule/Schedule/GraphTiler.cs
index 0cba90c4e..8bf585043 100644
--- a/src/Nncase.Schedule/Schedule/GraphTiler.cs
+++ b/src/Nncase.Schedule/Schedule/GraphTiler.cs
@@ -410,7 +410,7 @@ public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary();
foreach (var (node, diminfo) in tileableNodeMemo)
{
- searchAbleVars.AddRange(diminfo.TileVars.Select(i => i.Var()));
+ searchAbleVars.AddRange(diminfo.TileVars.Select(i => i.Var()).Reverse());
collector.Add(diminfo.TileVars.Select(i => i.Var()).ToArray());
collector.Add(diminfo.ForwardExtents.Select(x => x.Var()).ToArray());
}
@@ -465,7 +465,7 @@ public static TreeSolveResult SolvePrimGraph(TileNode primTree, Dictionary
+internal sealed class AutoTileReconstructor : ExprReconstructor
{
- public AutoTileReConstructor(GraphTiler tiler, string moduleKind, CompileOptions compileOptions, CondensationGraphAlgorithm algo)
+ public AutoTileReconstructor(GraphTiler tiler, string moduleKind, CompileOptions compileOptions, CondensationGraphAlgorithm algo)
: base(algo)
{
Tiler = tiler;
diff --git a/src/Nncase.Tests/Affine/UnitTestTileGraph.cs b/src/Nncase.Tests/Affine/UnitTestTileGraph.cs
index 360685cd6..6810171d1 100644
--- a/src/Nncase.Tests/Affine/UnitTestTileGraph.cs
+++ b/src/Nncase.Tests/Affine/UnitTestTileGraph.cs
@@ -15,6 +15,7 @@
namespace Nncase.Tests.AffineTest;
+[Collection(nameof(TargetTest.NotThreadSafeResourceCollection))]
[TestFixture.AutoSetupTestMethod(InitSession = true)]
public sealed class UnitTestTileGraph : TestClassBase
{
diff --git a/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs b/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs
index efddd0194..02e9a6d78 100644
--- a/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs
+++ b/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs
@@ -223,8 +223,7 @@ public async Task TestSwish(int[] shape, int count)
[Theory]
[InlineData(new object[] { new[] { 4, 8, 16, 32 }, new[] { 1 }, 0 })]
- [InlineData(new object[] { new[] { 4, 8, 16, 32 }, new[] { 2 }, 1 })]
- [InlineData(new object[] { new[] { 4, 8, 16, 32 }, new[] { 4 }, 2 })]
+ [InlineData(new object[] { new[] { 1, 64, 384, 128 }, new[] { 4 }, 1 })]
public async Task TestUnary(int[] shape, int[] hierarchy, int count)
{
var targetOptions = (CpuTargetOptions)CompileOptions.TargetOptions;
@@ -515,6 +514,7 @@ public async Task TestPackReshape(int[] inshape, int[] outshape, int packRank, i
[Theory]
[InlineData([new int[] { 2, 8, 16, 2 }, new int[] { 0, 2, 1, 3 }, 2, 0])]
+ [InlineData([new int[] { 1, 64, 384, 128 }, new int[] { 0, 2, 1, 3 }, 2, 1])]
public async Task TestTranspose(int[] shape, int[] perm, int rank, int number)
{
var input = new Var("input", new TensorType(DataTypes.Float32, shape));