diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceExtensions.cs b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceExtensions.cs index 6b686a405..9a3344056 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceExtensions.cs +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/CSourceExtensions.cs @@ -22,7 +22,8 @@ internal static class CSourceExtensions { DataTypes.UInt16, "uint16_t" }, { DataTypes.UInt32, "uint32_t" }, { DataTypes.UInt64, "uint64_t" }, - { DataTypes.Float16, "half" }, + { DataTypes.Float16, "nncase::half" }, + { DataTypes.BFloat16, "nncase::bfloat16" }, { DataTypes.Float32, "float" }, { DataTypes.Float64, "double" }, { DataTypes.Float8E4M3, "nncase::float_e4m3_t" }, diff --git a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/thread_main.cpp.cshtml b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/thread_main.cpp.cshtml index 33a070c75..2009b54b3 100644 --- a/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/thread_main.cpp.cshtml +++ b/modules/Nncase.Modules.CPU/CodeGen/CPU/Templates/thread_main.cpp.cshtml @@ -9,6 +9,8 @@ #include #include "topo_aware_runtime.h" #include +#include +#include @foreach(var (s,i) in Model.Options.MemoryCapacities.Select((s,i) => (s,i)).SkipLast(1)){ @:uint8_t L@(i+1)Data[@(s)]; } diff --git a/modules/Nncase.Modules.CPU/Passes/PassUtility.cs b/modules/Nncase.Modules.CPU/Passes/PassUtility.cs index b35f347d4..60a258ad9 100644 --- a/modules/Nncase.Modules.CPU/Passes/PassUtility.cs +++ b/modules/Nncase.Modules.CPU/Passes/PassUtility.cs @@ -87,15 +87,6 @@ public static bool IsCpuSupported(Op op, Expr expr, IEnumerable arguments, break; - case IR.Tensors.Cast cast: - var inType = arguments.ToArray()[0].CheckedDataType; - if (inType == DataTypes.Float16 || inType == DataTypes.BFloat16 || cast.NewType == DataTypes.Float16 || cast.NewType == DataTypes.BFloat16) - { - return false; - } - - break; - case IR.Tensors.Expand expand: if (arguments.ToArray()[0].CheckedShape.Rank != arguments.ToArray()[1].CheckedShape.Size) { diff --git a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerCast.cs b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerCast.cs index 9a7149b8e..6ad638f78 100644 --- a/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerCast.cs +++ b/modules/Nncase.Modules.CPU/Passes/Rules/CPU/Affine/LowerCast.cs @@ -35,9 +35,9 @@ public LowerCast(string moduleKind = CPUTarget.Kind) _ => true, IsWildcard("input") with { TypePattern = HasShape(s => s.Rank > 0 && s.IsFixed, "tileable") }); - private Expr GetReplace(Cast cast, Expr input) + private Expr GetReplace(Cast cast, Expr call, Expr input) { - var outBuffer = input.CheckedType switch + var outBuffer = call.CheckedType switch { TensorType t => IR.F.Buffer.Uninitialized(t.DType, TIR.MemoryLocation.Data, t.Shape.ToValueArray()), DistributedType dt => IR.F.Buffer.Uninitialized(dt.TensorType.DType, TIR.MemoryLocation.Data, dt.TensorType.Shape.ToValueArray(), dt.NdSBP, dt.Placement), diff --git a/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs b/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs index f2d1fde19..cdf791907 100644 --- a/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs +++ b/src/Nncase.Tests/Targets/UnitTestCPUKernels.cs @@ -335,12 +335,15 @@ public async Task TestResizeImage(int[] shape, ImageResizeMode resizeMode, int[] } [Theory] - [InlineData(new object[] { new[] { 1, 256, 64, 64 }, Runtime.TypeCode.Float8E4M3, 0 })] - public async Task TestPackCast(int[] shape, Nncase.Runtime.TypeCode type, int count) + [InlineData(new object[] { new[] { 1, 256, 64, 64 }, Runtime.TypeCode.Float8E4M3, Runtime.TypeCode.Float32, 0 })] + [InlineData(new object[] { new[] { 1, 64, 64, 256 }, Runtime.TypeCode.Float16, Runtime.TypeCode.BFloat16, 1 })] + [InlineData(new object[] { new[] { 1, 64, 256, 64 }, Runtime.TypeCode.BFloat16, Runtime.TypeCode.Float16, 2 })] + public async Task TestPackCast(int[] shape, Nncase.Runtime.TypeCode type1, Nncase.Runtime.TypeCode type2, int count) { var input = new Var(new TensorType(DataTypes.Float32, shape)); - var casted = IR.F.Tensors.Cast(input, DataType.FromTypeCode(type)); - var pre = IR.F.Tensors.Cast(casted, DataTypes.Float32); + var casted1 = IR.F.Tensors.Cast(input, DataType.FromTypeCode(type1)); + var casted2 = IR.F.Tensors.Cast(casted1, DataType.FromTypeCode(type2)); + var pre = IR.F.Tensors.Cast(casted2, DataTypes.Float32); var feedDict = new Dictionary() { { input, IR.F.Random.Normal(DataTypes.Float32, 0, 1, 1, shape).Evaluate() },