Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/Further Tiling #1292

Merged
merged 5 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/compiler-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ jobs:
<EnvironmentVariables>
<LD_LIBRARY_PATH>${{github.workspace}}/install/lib</LD_LIBRARY_PATH>
<DYLD_LIBRARY_PATH>${{github.workspace}}/install/lib</DYLD_LIBRARY_PATH>
<NNCASE_TILING_MAX_SOLUTIONS>1</NNCASE_TILING_MAX_SOLUTIONS>
</EnvironmentVariables>
</RunConfiguration>
</RunSettings>
Expand Down Expand Up @@ -248,6 +249,7 @@ jobs:
shell: bash
env:
NNCASE_COMPILER: ${{github.workspace}}/install/Nncase.Compiler.dll
NNCASE_TILING_MAX_SOLUTIONS: 1
run: |
dotnet tool install --global dotnet-coverage
dotnet-coverage collect -s tools/dotnet_coverage.settings.xml -f cobertura -o coverage/onnx_basic.xml pytest tests/importer/onnx_/basic/ --doctest-modules --junitxml=test_results/onnx_basic.xml
Expand Down
4 changes: 0 additions & 4 deletions modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceCompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,7 @@ private string ArgumentsSpecific(string sourcePath, string outPath)
var archConfig = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ?
"-DCMAKE_C_COMPILER=clang-cl -DCMAKE_CXX_COMPILER=clang-cl" : string.Empty;

#if DEBUG
var config = "Debug";
#else
var config = "Release";
#endif
var script = $"""
cd {sourcePath} &&
cmake -E remove_directory build &&
Expand Down
7 changes: 4 additions & 3 deletions modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,12 @@ internal static class CSourceExtensions
{ DataTypes.UInt16, "uint16_t" },
{ DataTypes.UInt32, "uint32_t" },
{ DataTypes.UInt64, "uint64_t" },
{ DataTypes.Float16, "half" },
{ DataTypes.Float16, "nncase::half" },
{ DataTypes.BFloat16, "nncase::bfloat16" },
{ DataTypes.Float32, "float" },
{ DataTypes.Float64, "double" },
{ DataTypes.Float8E4M3, "float_e4m3_t" },
{ DataTypes.Float8E5M2, "float_e5m2_t" },
{ DataTypes.Float8E4M3, "nncase::float_e4m3_t" },
{ DataTypes.Float8E5M2, "nncase::float_e5m2_t" },
};

public static string ToC(this PrimType primType) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -324,17 +324,41 @@ protected override CSymbol VisitCall(Call expr)
IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Matmul.cshtml", new TypedKernelTemplateModel<TIR.CPU.Matmul>(matmul)
{
Arguments = arguments.Select(x => new KernelArgument { Symbol = x }).ToArray(),
Indent = string.Join(string.Empty, Enumerable.Repeat(' ', IndentScope.Writer.Indent)),
Indent = new string(' ', IndentScope.Writer.Indent),
}).Result);

break;
case TIR.CPU.Pack pack:
IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Pack.cshtml", new TypedKernelTemplateModel<TIR.CPU.Pack>(pack)
{
Arguments = arguments.Select(x => new KernelArgument { Symbol = x }).ToArray(),
Indent = string.Join(string.Empty, Enumerable.Repeat(' ', IndentScope.Writer.Indent)),
Indent = new string(' ', IndentScope.Writer.Indent),
}).Result);
break;
case TIR.CPU.Transpose transpose:
IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Transpose.cshtml", new TypedKernelTemplateModel<TIR.CPU.Transpose>(transpose)
{
Arguments = arguments.Select(x => new KernelArgument { Symbol = x }).ToArray(),
Indent = new string(' ', IndentScope.Writer.Indent),
}).Result);
break;
case TIR.CPU.Unpack unpack:
IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Unpack.cshtml", new TypedKernelTemplateModel<TIR.CPU.Unpack>(unpack)
{
Arguments = arguments.Select(x => new KernelArgument { Symbol = x }).ToArray(),
Indent = new string(' ', IndentScope.Writer.Indent),
}).Result);
break;
case TIR.CPU.Reduce reduce:
IndentScope.Writer.Write(RazorTemplateEngine.RenderAsync("~/CodeGen/CPU/Templates/Kernels/Reduce.cshtml", new TypedKernelTemplateModel<TIR.CPU.Reduce>(reduce)
{
Arguments = arguments.Select(x => new KernelArgument { Symbol = x }).ToArray(),
Indent = new string(' ', IndentScope.Writer.Indent),
}).Result);
break;
case TIR.CPU.Cast cast:
IndentScope.Writer.IndWrite($"cast({arguments[0].Name}, {arguments[1].Name});\n");
break;
default:
throw new NotSupportedException();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
@model Nncase.CodeGen.CPU.TypedKernelTemplateModel<Nncase.TIR.CPU.Reduce>
@(Model.Indent)if (@Html.Raw(Model.Arguments[2].Symbol.Name)) {
@(Model.Indent) reduce_@(Model.Target.ReduceOp.ToC())<fixed_shape<@(string.Join(",", Model.Target.Axes))>, fixed_shape<@(string.Join(",", Model.Target.PackedAxes))>, fixed_shape<@(string.Join(",", Model.Target.PadedNums))>, true>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name));
@(Model.Indent)} else {
@(Model.Indent) reduce_@(Model.Target.ReduceOp.ToC())<fixed_shape<@(string.Join(",", Model.Target.Axes))>, fixed_shape<@(string.Join(",", Model.Target.PackedAxes))>, fixed_shape<@(string.Join(",", Model.Target.PadedNums))>, false>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name));
@(Model.Indent)}

Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@model Nncase.CodeGen.CPU.TypedKernelTemplateModel<Nncase.TIR.CPU.Transpose>
@{
}
transpose<fixed_shape<@string.Join(",", Model.Target.Perm)>>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name));
@(Model.Indent)transpose<fixed_shape<@string.Join(",", Model.Target.Perm)>>(@Html.Raw(Model.Arguments[0].Symbol.Name), @Html.Raw(Model.Arguments[1].Symbol.Name));
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

#include <nncase/ntt/runtime.h>
#include "topo_aware_runtime.h"
#include <nncase/float8.h>
#include <nncase/half.h>
#include <nncase/bfloat16.h>
@foreach(var (s,i) in Model.Options.MemoryCapacities.Select((s,i) => (s,i)).SkipLast(1)){
@:uint8_t L@(i+1)Data[@(s)];
}
Expand Down
22 changes: 21 additions & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Cast.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,32 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using Google.OrTools.ConstraintSolver;
using Nncase.IR;
using Nncase.Schedule;
using Nncase.TIR.CPU;

namespace Nncase.Evaluator.TIR.CPU;

public sealed class CastEvaluator : ITypeInferencer<Cast>
/// <summary>
/// Type inferencer and micro-kernel info provider for the TIR CPU <c>Cast</c> op.
/// </summary>
public sealed class CastEvaluator : ITypeInferencer<Cast>, IKernelInfoEvaluator<Cast>
{
    /// <summary>
    /// Infers the type of a <c>Cast</c> call; the result is always <c>TupleType.Void</c>.
    /// </summary>
    /// <param name="context">Type inference context (not consulted).</param>
    /// <param name="target">The op being inferred (not consulted).</param>
    /// <returns><c>TupleType.Void</c>.</returns>
    public IRType Visit(ITypeInferenceContext context, Cast target)
    {
        return TupleType.Void;
    }

    /// <summary>
    /// Builds the micro-kernel description used by the tiling scheduler.
    /// </summary>
    /// <param name="op">The op being described (not consulted).</param>
    /// <param name="context">Supplies the access maps, buffer shapes and target options.</param>
    /// <returns>
    /// A kernel info with a primitive of 1 and a multiplier range of [1, int.MaxValue] for every
    /// domain dimension of the first access map, buffer 0 flagged as read and buffer 1 as written.
    /// </returns>
    public MicroKernelInfo Visit(Cast op, MicroKernelContext context)
    {
        var rank = context.AccessMaps[0].Domains.Length;
        var unitPrimitives = Enumerable.Repeat(1, rank).ToArray();
        var openMultipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), rank).ToArray();

        // Sized by the number of buffers in the context; only the first two slots are populated.
        var buffers = new MicroKernelBufferInfo[context.BufferShapes.Length];
        var cpuOptions = (ICpuTargetOptions)context.TargetOptions;

        // Both bandwidth arguments of both buffers come from entry 1 of the target's bandwidth table.
        var bandWidth = cpuOptions.MemoryBandWidths[1];
        buffers[0] = new MicroKernelBufferInfo(bandWidth, bandWidth, MicroKernelBufferInfo.BufferState.Read);
        buffers[1] = new MicroKernelBufferInfo(bandWidth, bandWidth, MicroKernelBufferInfo.BufferState.Write);

        return new MicroKernelInfo(unitPrimitives, openMultipliers, buffers, GetComputeCycle);
    }

    /// <summary>
    /// Compute-cycle expression handed to the solver.
    /// </summary>
    /// <remarks>
    /// <c>factor</c> is the last extent of the context's first buffer shape, capped at 32. The result is
    /// <c>factor * (1 + b)</c>, where <c>b</c> is the solver variable produced by <c>MakeIsLessVar</c> for
    /// "last tiled extent of buffer 0 is less than factor" (presumably a 0/1 indicator, per OR-Tools).
    /// </remarks>
    private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context)
    {
        var factor = System.Math.Min(context.BufferShapes[0][^1], 32);
        var innermostExtent = bufferShapes[0][^1];
        var isBelowFactor = solver.MakeIsLessVar(innermostExtent, solver.MakeIntConst(factor));
        return factor * (1 + isBelowFactor);
    }
}
22 changes: 21 additions & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Reduce.cs
Original file line number Diff line number Diff line change
@@ -1,18 +1,38 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using Google.OrTools.ConstraintSolver;
using Nncase.Evaluator;
using Nncase.IR;
using Nncase.Schedule;
using Nncase.TIR.CPU;

namespace Nncase.Evaluator.TIR.CPU;

public sealed class ReduceEvaluator : ITypeInferencer<Reduce>
/// <summary>
/// Type inferencer and micro-kernel info provider for the TIR CPU <c>Reduce</c> op.
/// </summary>
public sealed class ReduceEvaluator : ITypeInferencer<Reduce>, IKernelInfoEvaluator<Reduce>
{
    /// <summary>
    /// Checks that the input and output arguments are tensor typed, then infers <c>TupleType.Void</c>.
    /// </summary>
    /// <param name="context">Type inference context used to check the argument types.</param>
    /// <param name="target">The op whose arguments are checked.</param>
    /// <returns><c>TupleType.Void</c>.</returns>
    public IRType Visit(ITypeInferenceContext context, Reduce target)
    {
        // The checked types themselves are not needed; the calls are made for their validation.
        _ = context.CheckArgumentType<TensorType>(target, Reduce.Input);
        _ = context.CheckArgumentType<TensorType>(target, Reduce.Output);
        return TupleType.Void;
    }

    /// <summary>
    /// Builds the micro-kernel description used by the tiling scheduler.
    /// </summary>
    /// <param name="op">The op being described (not consulted).</param>
    /// <param name="context">Supplies the access maps, buffer shapes and target options.</param>
    /// <returns>
    /// A kernel info with a primitive of 1 and a multiplier range of [1, int.MaxValue] for every
    /// domain dimension of the first access map, buffer 0 flagged as read and buffer 1 as written.
    /// </returns>
    public MicroKernelInfo Visit(Reduce op, MicroKernelContext context)
    {
        var rank = context.AccessMaps[0].Domains.Length;
        var unitPrimitives = Enumerable.Repeat(1, rank).ToArray();
        var openMultipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), rank).ToArray();

        // Sized by the number of buffers in the context; only the first two slots are populated.
        var buffers = new MicroKernelBufferInfo[context.BufferShapes.Length];
        var cpuOptions = (ICpuTargetOptions)context.TargetOptions;

        // Both bandwidth arguments of both buffers come from entry 1 of the target's bandwidth table.
        var bandWidth = cpuOptions.MemoryBandWidths[1];
        buffers[0] = new MicroKernelBufferInfo(bandWidth, bandWidth, MicroKernelBufferInfo.BufferState.Read);
        buffers[1] = new MicroKernelBufferInfo(bandWidth, bandWidth, MicroKernelBufferInfo.BufferState.Write);

        return new MicroKernelInfo(unitPrimitives, openMultipliers, buffers, GetComputeCycle);
    }

    /// <summary>
    /// Compute-cycle expression handed to the solver.
    /// </summary>
    /// <remarks>
    /// <c>factor</c> is the last extent of the context's first buffer shape, capped at 32. The result is
    /// <c>factor * (1 + b)</c>, where <c>b</c> is the solver variable produced by <c>MakeIsLessVar</c> for
    /// "last tiled extent of buffer 0 is less than factor" (presumably a 0/1 indicator, per OR-Tools).
    /// </remarks>
    private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context)
    {
        var factor = System.Math.Min(context.BufferShapes[0][^1], 32);
        var innermostExtent = bufferShapes[0][^1];
        var isBelowFactor = solver.MakeIsLessVar(innermostExtent, solver.MakeIntConst(factor));
        return factor * (1 + isBelowFactor);
    }
}
22 changes: 21 additions & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Transpose.cs
Original file line number Diff line number Diff line change
@@ -1,18 +1,38 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using Google.OrTools.ConstraintSolver;
using Nncase.Evaluator;
using Nncase.IR;
using Nncase.Schedule;
using Nncase.TIR.CPU;

namespace Nncase.Evaluator.TIR.CPU;

public sealed class TransposeEvaluator : ITypeInferencer<Transpose>
/// <summary>
/// Type inferencer and micro-kernel info provider for the TIR CPU <c>Transpose</c> op.
/// </summary>
public sealed class TransposeEvaluator : ITypeInferencer<Transpose>, IKernelInfoEvaluator<Transpose>
{
    /// <summary>
    /// Checks that the input and output arguments are tensor typed, then infers <c>TupleType.Void</c>.
    /// </summary>
    /// <param name="context">Type inference context used to check the argument types.</param>
    /// <param name="target">The op whose arguments are checked.</param>
    /// <returns><c>TupleType.Void</c>.</returns>
    public IRType Visit(ITypeInferenceContext context, Transpose target)
    {
        // The checked types themselves are not needed; the calls are made for their validation.
        _ = context.CheckArgumentType<TensorType>(target, Transpose.Input);
        _ = context.CheckArgumentType<TensorType>(target, Transpose.Output);
        return TupleType.Void;
    }

    /// <summary>
    /// Builds the micro-kernel description used by the tiling scheduler.
    /// </summary>
    /// <param name="op">The op being described (not consulted).</param>
    /// <param name="context">Supplies the access maps, buffer shapes and target options.</param>
    /// <returns>
    /// A kernel info with a primitive of 1 and a multiplier range of [1, int.MaxValue] for every
    /// domain dimension of the first access map, buffer 0 flagged as read and buffer 1 as written.
    /// </returns>
    public MicroKernelInfo Visit(Transpose op, MicroKernelContext context)
    {
        var rank = context.AccessMaps[0].Domains.Length;
        var unitPrimitives = Enumerable.Repeat(1, rank).ToArray();
        var openMultipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), rank).ToArray();

        // Sized by the number of buffers in the context; only the first two slots are populated.
        var buffers = new MicroKernelBufferInfo[context.BufferShapes.Length];
        var cpuOptions = (ICpuTargetOptions)context.TargetOptions;

        // Both bandwidth arguments of both buffers come from entry 1 of the target's bandwidth table.
        var bandWidth = cpuOptions.MemoryBandWidths[1];
        buffers[0] = new MicroKernelBufferInfo(bandWidth, bandWidth, MicroKernelBufferInfo.BufferState.Read);
        buffers[1] = new MicroKernelBufferInfo(bandWidth, bandWidth, MicroKernelBufferInfo.BufferState.Write);

        return new MicroKernelInfo(unitPrimitives, openMultipliers, buffers, GetComputeCycle);
    }

    /// <summary>
    /// Compute-cycle expression handed to the solver.
    /// </summary>
    /// <remarks>
    /// <c>factor</c> is the last extent of the context's first buffer shape, capped at 32. The result is
    /// <c>factor * (1 + b)</c>, where <c>b</c> is the solver variable produced by <c>MakeIsLessVar</c> for
    /// "last tiled extent of buffer 0 is less than factor" (presumably a 0/1 indicator, per OR-Tools).
    /// </remarks>
    private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context)
    {
        var factor = System.Math.Min(context.BufferShapes[0][^1], 32);
        var innermostExtent = bufferShapes[0][^1];
        var isBelowFactor = solver.MakeIsLessVar(innermostExtent, solver.MakeIntConst(factor));
        return factor * (1 + isBelowFactor);
    }
}
22 changes: 21 additions & 1 deletion modules/Nncase.Modules.CPU/Evaluator/TIR/CPU/Unpack.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,36 @@
using System.Diagnostics;
using System.Linq;
using DryIoc.ImTools;
using Google.OrTools.ConstraintSolver;
using Nncase.CostModel;
using Nncase.IR;
using Nncase.Schedule;
using Nncase.TIR.CPU;
using Nncase.Utilities;
using OrtKISharp;

namespace Nncase.Evaluator.TIR.CPU;

public sealed class UnpackEvaluator : ITypeInferencer<Unpack>
/// <summary>
/// Type inferencer and micro-kernel info provider for the TIR CPU <c>Unpack</c> op.
/// </summary>
public sealed class UnpackEvaluator : ITypeInferencer<Unpack>, IKernelInfoEvaluator<Unpack>
{
    /// <inheritdoc/>
    public IRType Visit(ITypeInferenceContext context, Unpack target)
    {
        // No argument checks are performed here; the inferred type is unconditionally void.
        return TupleType.Void;
    }

    /// <summary>
    /// Builds the micro-kernel description used by the tiling scheduler.
    /// </summary>
    /// <param name="op">The op being described (not consulted).</param>
    /// <param name="context">Supplies the access maps, buffer shapes and target options.</param>
    /// <returns>
    /// A kernel info with a primitive of 1 and a multiplier range of [1, int.MaxValue] for every
    /// domain dimension of the first access map, buffer 0 flagged as read and buffer 1 as written.
    /// </returns>
    public MicroKernelInfo Visit(Unpack op, MicroKernelContext context)
    {
        var rank = context.AccessMaps[0].Domains.Length;
        var unitPrimitives = Enumerable.Repeat(1, rank).ToArray();
        var openMultipliers = Enumerable.Repeat(new ValueRange<int>(1, int.MaxValue), rank).ToArray();

        // Sized by the number of buffers in the context; only the first two slots are populated.
        var buffers = new MicroKernelBufferInfo[context.BufferShapes.Length];
        var cpuOptions = (ICpuTargetOptions)context.TargetOptions;

        // Both bandwidth arguments of both buffers come from entry 1 of the target's bandwidth table.
        var bandWidth = cpuOptions.MemoryBandWidths[1];
        buffers[0] = new MicroKernelBufferInfo(bandWidth, bandWidth, MicroKernelBufferInfo.BufferState.Read);
        buffers[1] = new MicroKernelBufferInfo(bandWidth, bandWidth, MicroKernelBufferInfo.BufferState.Write);

        return new MicroKernelInfo(unitPrimitives, openMultipliers, buffers, GetComputeCycle);
    }

    /// <summary>
    /// Compute-cycle expression handed to the solver.
    /// </summary>
    /// <remarks>
    /// <c>factor</c> is the last extent of the context's first buffer shape, capped at 32. The result is
    /// <c>factor * (1 + b)</c>, where <c>b</c> is the solver variable produced by <c>MakeIsLessVar</c> for
    /// "last tiled extent of buffer 0 is less than factor" (presumably a 0/1 indicator, per OR-Tools).
    /// </remarks>
    private static IntExpr GetComputeCycle(IntExpr[][] bufferShapes, Solver solver, MicroKernelContext context)
    {
        var factor = System.Math.Min(context.BufferShapes[0][^1], 32);
        var innermostExtent = bufferShapes[0][^1];
        var isBelowFactor = solver.MakeIsLessVar(innermostExtent, solver.MakeIntConst(factor));
        return factor * (1 + isBelowFactor);
    }
}
}
6 changes: 3 additions & 3 deletions modules/Nncase.Modules.CPU/Passes/CPUFunctionPartitionPass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,15 @@ bool CheckField(Expr f)
}

// 3. reconstruction
var constructor = new DistributedReConstructor(funcName, ModuleKind, condenseAlgo);
var constructor = new DistributedReconstructor(funcName, ModuleKind, condenseAlgo);
var post = constructor.Construct();
return post;
}
}

internal sealed class DistributedReConstructor : ExprReConstructor<ExprVertex, ExprEdge>
internal sealed class DistributedReconstructor : ExprReconstructor<ExprVertex, ExprEdge>
{
public DistributedReConstructor(string funcName, string moduleKind, CondensationGraphAlgorithm<ExprVertex, ExprEdge> algo)
public DistributedReconstructor(string funcName, string moduleKind, CondensationGraphAlgorithm<ExprVertex, ExprEdge> algo)
: base(algo)
{
FuncName = funcName;
Expand Down
9 changes: 0 additions & 9 deletions modules/Nncase.Modules.CPU/Passes/PassUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,6 @@ public static bool IsCpuSupported(Op op, Expr expr, IEnumerable<Expr> arguments,

break;

case IR.Tensors.Cast cast:
var inType = arguments.ToArray()[0].CheckedDataType;
if (inType == DataTypes.Float16 || inType == DataTypes.BFloat16 || cast.NewType == DataTypes.Float16 || cast.NewType == DataTypes.BFloat16)
{
return false;
}

break;

case IR.Tensors.Expand expand:
if (arguments.ToArray()[0].CheckedShape.Rank != arguments.ToArray()[1].CheckedShape.Size)
{
Expand Down
55 changes: 55 additions & 0 deletions modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerCast.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Nncase.IR;
using Nncase.IR.Affine;
using Nncase.IR.Math;
using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using Nncase.Targets;
using static Nncase.IR.F.CPU;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.Utility;

namespace Nncase.Passes.Rules.CPU.Affine;

/// <summary>
/// Rewrite rule that lowers a tensor <c>Cast</c> call into an affine grid whose body is the
/// TIR CPU cast kernel. Only inputs with a fixed shape of rank greater than zero are matched.
/// </summary>
[RuleGenerator]
public partial class LowerCast : RewriteRule<Pattern>
{
    /// <summary>
    /// Initializes a new instance of the <see cref="LowerCast"/> class.
    /// </summary>
    /// <param name="moduleKind">Module kind passed to the generated grid; defaults to the CPU target kind.</param>
    public LowerCast(string moduleKind = CPUTarget.Kind)
    {
        ModuleKind = moduleKind;
    }

    /// <summary>
    /// Gets the module kind the generated grid is created for.
    /// </summary>
    public string ModuleKind { get; }

    /// <inheritdoc/>
    public override Pattern Pattern { get; } = PatternMatch.F.Tensors.IsCast(
        "cast",
        "call",
        _ => true,
        IsWildcard("input") with { TypePattern = HasShape(s => s.Rank > 0 && s.IsFixed, "tileable") });

    /// <summary>
    /// Builds the replacement grid for a matched cast call.
    /// </summary>
    /// <param name="cast">The matched cast op; supplies the new type and cast mode.</param>
    /// <param name="call">The matched call; its checked type determines the output buffer.</param>
    /// <param name="input">The cast input; its rank defines the grid domain.</param>
    /// <returns>The grid expression reading <paramref name="input"/> and writing a fresh output buffer.</returns>
    /// <exception cref="ArgumentOutOfRangeException">
    /// The checked type of <paramref name="call"/> is neither a tensor type nor a distributed type.
    /// </exception>
    private Expr GetReplace(Cast cast, Expr call, Expr input)
    {
        // The output buffer mirrors the call's checked type; distributed types also carry their
        // SBP and placement over to the buffer.
        var outBuffer = call.CheckedType switch
        {
            TensorType t => IR.F.Buffer.Uninitialized(t.DType, TIR.MemoryLocation.Data, t.Shape.ToValueArray()),
            DistributedType dt => IR.F.Buffer.Uninitialized(dt.TensorType.DType, TIR.MemoryLocation.Data, dt.TensorType.Shape.ToValueArray(), dt.NdSBP, dt.Placement),

            // The value being matched comes from `call`, so that is the parameter to report.
            _ => throw new ArgumentOutOfRangeException(nameof(call), $"Unsupported checked type for cast lowering: {call.CheckedType}"),
        };

        // Input and output are both accessed through identity maps over the same domain.
        var rank = input.CheckedShape.Rank;
        return IR.F.Affine.Grid(ModuleKind)
            .Domain(rank, out _)
            .Read(input, AffineMap.Identity(rank), out var inTile)
            .Write(outBuffer, AffineMap.Identity(rank), out var outTile)
            .Body(TIR.F.CPU.Cast(inTile, outTile, cast.NewType, cast.CastMode))
            .Build();
    }
}
Loading
Loading