From e87d746a1bbd79f6e998ffddbeeedaef55d3b143 Mon Sep 17 00:00:00 2001 From: MoFtZ Date: Wed, 31 Jan 2024 21:28:39 +1100 Subject: [PATCH] Added workaround for Cuda Test Runner. --- .../Generic/AlgorithmsTestBase.cs | 107 +++++++++--------- Src/ILGPU.Algorithms.Tests/XMathTests.Pow.tt | 42 ++++++- 2 files changed, 94 insertions(+), 55 deletions(-) diff --git a/Src/ILGPU.Algorithms.Tests/Generic/AlgorithmsTestBase.cs b/Src/ILGPU.Algorithms.Tests/Generic/AlgorithmsTestBase.cs index 78c86441f..5ba15cc55 100644 --- a/Src/ILGPU.Algorithms.Tests/Generic/AlgorithmsTestBase.cs +++ b/Src/ILGPU.Algorithms.Tests/Generic/AlgorithmsTestBase.cs @@ -1,6 +1,6 @@ // --------------------------------------------------------------------------------------- // ILGPU Algorithms -// Copyright (c) 2020-2023 ILGPU Project +// Copyright (c) 2020-2024 ILGPU Project // www.ilgpu.net // // File: AlgorithmsTestBase.cs @@ -27,7 +27,7 @@ protected AlgorithmsTestBase(ITestOutputHelper output, TestContext testContext) /// /// Compares two numbers for equality, within a defined tolerance. /// - private class HalfPrecisionComparer + internal class HalfPrecisionComparer : EqualityComparer { public readonly float Margin; @@ -59,7 +59,7 @@ public override int GetHashCode(Half obj) => /// /// Compares two numbers for equality, within a defined tolerance. /// - private class FloatPrecisionComparer + internal class FloatPrecisionComparer : EqualityComparer { public readonly float Margin; @@ -91,7 +91,7 @@ public override int GetHashCode(float obj) => /// /// Compares two numbers for equality, within a defined tolerance. /// - private class DoublePrecisionComparer + internal class DoublePrecisionComparer : EqualityComparer { public readonly double Margin; @@ -123,7 +123,7 @@ public override int GetHashCode(double obj) => /// /// Compares two numbers for equality, within a defined tolerance. /// - private class HalfRelativeErrorComparer + internal class HalfRelativeErrorComparer : EqualityComparer { public readonly float RelativeError; @@ -163,7 +163,7 @@ public override int GetHashCode(Half obj) => /// /// Compares two numbers for equality, within a defined tolerance. /// - private class FloatRelativeErrorComparer + internal class FloatRelativeErrorComparer : EqualityComparer { public readonly float RelativeError; @@ -203,7 +203,7 @@ public override int GetHashCode(float obj) => /// /// Compares two numbers for equality, within a defined tolerance. /// - private class DoubleRelativeErrorComparer + internal class DoubleRelativeErrorComparer : EqualityComparer { public readonly double RelativeError; @@ -245,19 +245,33 @@ public override int GetHashCode(double obj) => /// /// The target buffer. /// The expected values. - /// The acceptable error margin. - public void VerifyWithinPrecision( - ArrayView buffer, - Half[] expected, - uint decimalPlaces) + /// The comparer to use. + public void VerifyUsingComparer( + ArrayView buffer, + T[] expected, + IEqualityComparer comparer) + where T : unmanaged { var data = buffer.GetAsArray(Accelerator.DefaultStream); Assert.Equal(data.Length, expected.Length); - - var comparer = new HalfPrecisionComparer(decimalPlaces); Assert.Equal(expected, data, comparer); } + /// + /// Verifies the contents of the given memory buffer. + /// + /// The target buffer. + /// The expected values. + /// The acceptable error margin. + public void VerifyWithinPrecision( + ArrayView buffer, + Half[] expected, + uint decimalPlaces) => + VerifyUsingComparer( + buffer, + expected, + new HalfPrecisionComparer(decimalPlaces)); + /// /// Verifies the contents of the given memory buffer. /// @@ -267,14 +281,11 @@ public void VerifyWithinPrecision( public void VerifyWithinPrecision( ArrayView buffer, float[] expected, - uint decimalPlaces) - { - var data = buffer.GetAsArray(Accelerator.DefaultStream); - Assert.Equal(data.Length, expected.Length); - - var comparer = new FloatPrecisionComparer(decimalPlaces); - Assert.Equal(expected, data, comparer); - } + uint decimalPlaces) => + VerifyUsingComparer( + buffer, + expected, + new FloatPrecisionComparer(decimalPlaces)); /// /// Verifies the contents of the given memory buffer. @@ -285,14 +296,11 @@ public void VerifyWithinPrecision( public void VerifyWithinPrecision( ArrayView buffer, double[] expected, - uint decimalPlaces) - { - var data = buffer.GetAsArray(Accelerator.DefaultStream); - Assert.Equal(data.Length, expected.Length); - - var comparer = new DoublePrecisionComparer(decimalPlaces); - Assert.Equal(expected, data, comparer); - } + uint decimalPlaces) => + VerifyUsingComparer( + buffer, + expected, + new DoublePrecisionComparer(decimalPlaces)); /// /// Verifies the contents of the given memory buffer. @@ -303,14 +311,11 @@ public void VerifyWithinPrecision( public void VerifyWithinRelativeError( ArrayView buffer, Half[] expected, - double relativeError) - { - var data = buffer.GetAsArray(Accelerator.DefaultStream); - Assert.Equal(data.Length, expected.Length); - - var comparer = new HalfRelativeErrorComparer((float)relativeError); - Assert.Equal(expected, data, comparer); - } + double relativeError) => + VerifyUsingComparer( + buffer, + expected, + new HalfRelativeErrorComparer((float)relativeError)); /// /// Verifies the contents of the given memory buffer. @@ -321,14 +326,11 @@ public void VerifyWithinRelativeError( public void VerifyWithinRelativeError( ArrayView buffer, float[] expected, - double relativeError) - { - var data = buffer.GetAsArray(Accelerator.DefaultStream); - Assert.Equal(data.Length, expected.Length); - - var comparer = new FloatRelativeErrorComparer((float)relativeError); - Assert.Equal(expected, data, comparer); - } + double relativeError) => + VerifyUsingComparer( + buffer, + expected, + new FloatRelativeErrorComparer((float)relativeError)); /// /// Verifies the contents of the given memory buffer. @@ -339,13 +341,10 @@ public void VerifyWithinRelativeError( public void VerifyWithinRelativeError( ArrayView buffer, double[] expected, - double relativeError) - { - var data = buffer.GetAsArray(Accelerator.DefaultStream); - Assert.Equal(data.Length, expected.Length); - - var comparer = new DoubleRelativeErrorComparer(relativeError); - Assert.Equal(expected, data, comparer); - } + double relativeError) => + VerifyUsingComparer( + buffer, + expected, + new DoubleRelativeErrorComparer(relativeError)); } } diff --git a/Src/ILGPU.Algorithms.Tests/XMathTests.Pow.tt b/Src/ILGPU.Algorithms.Tests/XMathTests.Pow.tt index 1b514b8d9..d453fe736 100644 --- a/Src/ILGPU.Algorithms.Tests/XMathTests.Pow.tt +++ b/Src/ILGPU.Algorithms.Tests/XMathTests.Pow.tt @@ -1,6 +1,6 @@ // --------------------------------------------------------------------------------------- // ILGPU Algorithms -// Copyright (c) 2020-2023 ILGPU Project +// Copyright (c) 2020-2024 ILGPU Project // www.ilgpu.net // // File: XMathTests.Pow.tt/XMathTests.Pow.cs @@ -48,6 +48,32 @@ namespace ILGPU.Algorithms.Tests // and ensures a minimum error on each accelerator type. partial class XMathTests { + #region Nested Types + + /// + /// Workaround for different output of LibDevice __nv_pow(double, double) and + /// .NET Math.Pow(double, double) on Cuda Test Runner. + /// + private class CudaPowDoubleRelativeErrorComparer : DoubleRelativeErrorComparer + { + public CudaPowDoubleRelativeErrorComparer(double relativeError) + : base(relativeError) + { } + + public override bool Equals(double x, double y) + { + if ((double.IsPositiveInfinity(x) && double.IsNegativeInfinity(y)) || + (double.IsNegativeInfinity(x) && double.IsPositiveInfinity(y))) + { + return true; + } + + return base.Equals(x, y); + } + } + + #endregion + <# foreach (var function in powFunctions) { #> internal static void <#= function.KernelName #>( Index1D index, @@ -120,10 +146,24 @@ namespace ILGPU.Algorithms.Tests v => Math<#= function.MathSuffix #>.<#= function.Name #>(v.X, v.Y)) .ToArray(); if (Accelerator.AcceleratorType == AcceleratorType.Cuda) +<# + if (function.DataType == "double") { +#> + VerifyUsingComparer( + output.View, + expected, + new CudaPowDoubleRelativeErrorComparer( + (<#= function.DataType #>)<#= function.RelativeError.Cuda #>)); +<# + } else { +#> VerifyWithinRelativeError( output.View, expected, <#= function.RelativeError.Cuda #>); +<# + } +#> else if (Accelerator.AcceleratorType == AcceleratorType.OpenCL) VerifyWithinRelativeError( output.View,