Skip to content

Commit

Permalink
fix cast tiling
Browse files Browse the repository at this point in the history
  • Loading branch information
zhen8838 committed Feb 7, 2025
1 parent b721565 commit 50afe78
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 16 deletions.
3 changes: 2 additions & 1 deletion modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ internal static class CSourceExtensions
{ DataTypes.UInt16, "uint16_t" },
{ DataTypes.UInt32, "uint32_t" },
{ DataTypes.UInt64, "uint64_t" },
{ DataTypes.Float16, "half" },
{ DataTypes.Float16, "nncase::half" },
{ DataTypes.BFloat16, "nncase::bfloat16" },
{ DataTypes.Float32, "float" },
{ DataTypes.Float64, "double" },
{ DataTypes.Float8E4M3, "nncase::float_e4m3_t" },
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
#include <nncase/ntt/runtime.h>
#include "topo_aware_runtime.h"
#include <nncase/float8.h>
#include <nncase/half.h>
#include <nncase/bfloat16.h>
@foreach(var (s,i) in Model.Options.MemoryCapacities.Select((s,i) => (s,i)).SkipLast(1)){
@:uint8_t L@(i+1)Data[@(s)];
}
Expand Down
9 changes: 0 additions & 9 deletions modules/Nncase.Modules.CPU/Passes/PassUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,15 +87,6 @@ public static bool IsCpuSupported(Op op, Expr expr, IEnumerable<Expr> arguments,

break;

case IR.Tensors.Cast cast:
var inType = arguments.ToArray()[0].CheckedDataType;
if (inType == DataTypes.Float16 || inType == DataTypes.BFloat16 || cast.NewType == DataTypes.Float16 || cast.NewType == DataTypes.BFloat16)
{
return false;
}

break;

case IR.Tensors.Expand expand:
if (arguments.ToArray()[0].CheckedShape.Rank != arguments.ToArray()[1].CheckedShape.Size)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ public LowerCast(string moduleKind = CPUTarget.Kind)
_ => true,
IsWildcard("input") with { TypePattern = HasShape(s => s.Rank > 0 && s.IsFixed, "tileable") });

private Expr GetReplace(Cast cast, Expr input)
private Expr GetReplace(Cast cast, Expr call, Expr input)
{
var outBuffer = input.CheckedType switch
var outBuffer = call.CheckedType switch
{
TensorType t => IR.F.Buffer.Uninitialized(t.DType, TIR.MemoryLocation.Data, t.Shape.ToValueArray()),
DistributedType dt => IR.F.Buffer.Uninitialized(dt.TensorType.DType, TIR.MemoryLocation.Data, dt.TensorType.Shape.ToValueArray(), dt.NdSBP, dt.Placement),
Expand Down
11 changes: 7 additions & 4 deletions src/Nncase.Tests/Targets/UnitTestCPUKernels.cs
Original file line number Diff line number Diff line change
Expand Up @@ -335,12 +335,15 @@ public async Task TestResizeImage(int[] shape, ImageResizeMode resizeMode, int[]
}

[Theory]
[InlineData(new object[] { new[] { 1, 256, 64, 64 }, Runtime.TypeCode.Float8E4M3, 0 })]
public async Task TestPackCast(int[] shape, Nncase.Runtime.TypeCode type, int count)
[InlineData(new object[] { new[] { 1, 256, 64, 64 }, Runtime.TypeCode.Float8E4M3, Runtime.TypeCode.Float32, 0 })]
[InlineData(new object[] { new[] { 1, 64, 64, 256 }, Runtime.TypeCode.Float16, Runtime.TypeCode.BFloat16, 1 })]
[InlineData(new object[] { new[] { 1, 64, 256, 64 }, Runtime.TypeCode.BFloat16, Runtime.TypeCode.Float16, 2 })]
public async Task TestPackCast(int[] shape, Nncase.Runtime.TypeCode type1, Nncase.Runtime.TypeCode type2, int count)
{
var input = new Var(new TensorType(DataTypes.Float32, shape));
var casted = IR.F.Tensors.Cast(input, DataType.FromTypeCode(type));
var pre = IR.F.Tensors.Cast(casted, DataTypes.Float32);
var casted1 = IR.F.Tensors.Cast(input, DataType.FromTypeCode(type1));
var casted2 = IR.F.Tensors.Cast(casted1, DataType.FromTypeCode(type2));
var pre = IR.F.Tensors.Cast(casted2, DataTypes.Float32);

var feedDict = new Dictionary<Var, IValue>() {
{ input, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 1, shape).Evaluate() },
Expand Down

0 comments on commit 50afe78

Please sign in to comment.