Skip to content

Commit 63ed7c1

Browse files
committed
FunctionRegression domain: ported performance improvements from SharpNeat 4.x branch.
1 parent 1e180a2 commit 63ed7c1

File tree

9 files changed

+56
-38
lines changed

9 files changed

+56
-38
lines changed

src/SharpNeatDomains/FunctionRegression/FnRegressionEvaluator.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
* You should have received a copy of the MIT License
1010
* along with SharpNEAT; if not, see https://opensource.org/licenses/MIT.
1111
*/
12-
using System;
12+
using Redzen;
1313
using SharpNeat.Core;
1414
using SharpNeat.Phenomes;
1515

@@ -103,11 +103,11 @@ public FitnessInfo Evaluate(IBlackBox box)
103103
FnRegressionUtils.CalcGradients(_paramSamplingInfo, yArr, gradientArr);
104104

105105
// Calc y position mean squared error (MSE), and apply weighting.
106-
double yMse = FnRegressionUtils.CalcMeanSquaredError(yArr, _yArrTarget);
106+
double yMse = MathArrayUtils.MeanSquaredDelta(yArr, _yArrTarget);
107107
yMse *= _yMseWeight;
108108

109109
// Calc gradient mean squared error.
110-
double gradientMse = FnRegressionUtils.CalcMeanSquaredError(gradientArr, _gradientArrTarget);
110+
double gradientMse = MathArrayUtils.MeanSquaredDelta(gradientArr, _gradientArrTarget);
111111
gradientMse *= _gradientMseWeight;
112112

113113
// Calc fitness as the inverse of MSE (higher value is fitter).

src/SharpNeatDomains/FunctionRegression/FnRegressionUtils.cs

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
* along with SharpNEAT; if not, see https://opensource.org/licenses/MIT.
1111
*/
1212
using System;
13-
using System.Diagnostics;
13+
using System.Numerics;
1414

1515
namespace SharpNeat.Domains.FunctionRegression
1616
{
@@ -37,37 +37,59 @@ public static void CalcFunctionMidAndScale(IFunction fn, ParamSamplingInfo param
3737
mid = ((min+max) / 2.0);
3838
}
3939

40-
public static void CalcGradients (ParamSamplingInfo paramSamplingInfo, double[] yArr, double[] gradientArr)
40+
public static void CalcGradients(
41+
ParamSamplingInfo paramSamplingInfo,
42+
double[] yArr,
43+
double[] gradientArr)
4144
{
45+
// Notes.
46+
// The gradient at a sample point is approximated by taking the gradient of the line between the two
47+
// sample points either side of that point. For the first and last sample points we take the gradient
48+
// of the line between the sample point and its single adjacent sample point (as an alternative we could
49+
// sample an additional point at each end that doesn't get used for the function regression evaluation.
50+
//
51+
// This approach is rather crude, but fast. A better approach might be to do a polynomial regression on
52+
// the sample point and its nearest two adjacent samples, and then take the gradient of the polynomial
53+
// regression at the required point; obviously that would required more computational work to do so may
54+
// not be beneficial in the overall context of an evolutionary algorithm.
55+
//
56+
// Furthermore, the difference between this gradient approximation and the true gradient decreases with
57+
// increases sample density, therefore this is a reasonable approach *if* the sample density is
58+
// sufficiently high.
59+
4260
// Handle the end points as special cases.
4361
// First point.
4462
double[] xArr = paramSamplingInfo._xArr;
4563
gradientArr[0] = CalcGradient(xArr[0], yArr[0], xArr[1], yArr[1]);
4664

4765
// Intermediate points.
48-
for(int i=1; i< xArr.Length-1; i++) {
49-
gradientArr[i] = CalcGradient(xArr[i-1], yArr[i-1], xArr[i+1], yArr[i+1]);
50-
}
66+
int width = Vector<double>.Count;
67+
int i=1;
68+
for(; i < xArr.Length - width - 1; i += width)
69+
{
70+
// Calc a block of x deltas.
71+
var vecLeft = new Vector<double>(xArr, i - 1);
72+
var vecRight = new Vector<double>(xArr, i + 1);
73+
var xVecDelta = vecRight - vecLeft;
5174

52-
// Last point.
53-
int lastIdx = xArr.Length-1;
54-
gradientArr[lastIdx] = CalcGradient(xArr[lastIdx-1], yArr[lastIdx-1], xArr[lastIdx], yArr[lastIdx]);
55-
}
75+
// Calc a block of y deltas.
76+
vecLeft = new Vector<double>(yArr, i - 1);
77+
vecRight = new Vector<double>(yArr, i + 1);
78+
var yVecDelta = vecRight - vecLeft;
5679

57-
public static double CalcMeanSquaredError(double[] a, double[] b)
58-
{
59-
Debug.Assert(a.Length == b.Length);
80+
// Divide the y's by x's to obtain the gradients.
81+
var gradientVec = yVecDelta / xVecDelta;
6082

61-
// Calc sum(squared error).
62-
double total = 0.0;
63-
for(int i=0; i<a.Length; i++)
64-
{
65-
double err = a[i] - b[i];
66-
total += err * err;
83+
gradientVec.CopyTo(gradientArr, i);
84+
}
85+
86+
// Calc gradients for remaining intermediate points (if any).
87+
for (; i < xArr.Length - 1; i++) {
88+
gradientArr[i] = CalcGradient(xArr[i - 1], yArr[i - 1], xArr[i + 1], yArr[i + 1]);
6789
}
6890

69-
// Calculate mean squared error (MSE).
70-
return total / a.Length;
91+
// Last point.
92+
gradientArr[i] = CalcGradient(xArr[i - 1], yArr[i - 1], xArr[i], yArr[i]);
7193
}
7294

7395
#endregion
@@ -76,11 +98,7 @@ public static double CalcMeanSquaredError(double[] a, double[] b)
7698

7799
private static double CalcGradient(double x1, double y1, double x2, double y2)
78100
{
79-
double ydiff = y2 - y1;
80-
if(ydiff == 0.0) {
81-
return 0.0;
82-
}
83-
return ydiff / (x2-x1);
101+
return (y2 - y1) / (x2 - x1);
84102
}
85103

86104
#endregion

src/SharpNeatDomains/SharpNeatDomains.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@
4545
<HintPath>..\packages\log4net.2.0.8\lib\net45-full\log4net.dll</HintPath>
4646
<Private>True</Private>
4747
</Reference>
48-
<Reference Include="Redzen, Version=8.0.0.0, Culture=neutral, PublicKeyToken=182843a4be0a74f7, processorArchitecture=MSIL">
49-
<HintPath>..\packages\Redzen.8.0.0\lib\netstandard2.0\Redzen.dll</HintPath>
48+
<Reference Include="Redzen, Version=9.0.0.0, Culture=neutral, PublicKeyToken=182843a4be0a74f7, processorArchitecture=MSIL">
49+
<HintPath>..\packages\Redzen.9.0.0-alpha1\lib\netstandard2.0\Redzen.dll</HintPath>
5050
</Reference>
5151
<Reference Include="System" />
5252
<Reference Include="System.Buffers, Version=4.0.3.0, Culture=neutral, PublicKeyToken=cc7b13ffcd2ddd51, processorArchitecture=MSIL">

src/SharpNeatDomains/packages.config

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<?xml version="1.0" encoding="utf-8"?>
22
<packages>
33
<package id="log4net" version="2.0.8" targetFramework="net461" />
4-
<package id="Redzen" version="8.0.0" targetFramework="net472" />
4+
<package id="Redzen" version="9.0.0-alpha1" targetFramework="net48" />
55
<package id="System.Buffers" version="4.5.0" targetFramework="net471" />
66
<package id="System.Memory" version="4.5.2" targetFramework="net471" />
77
<package id="System.Numerics.Vectors" version="4.5.0" targetFramework="net471" />

src/SharpNeatDomainsExtra/SharpNeatDomainsExtra.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@
4343
<HintPath>..\packages\log4net.2.0.8\lib\net45-full\log4net.dll</HintPath>
4444
<Private>True</Private>
4545
</Reference>
46-
<Reference Include="Redzen, Version=8.0.0.0, Culture=neutral, PublicKeyToken=182843a4be0a74f7, processorArchitecture=MSIL">
47-
<HintPath>..\packages\Redzen.8.0.0\lib\netstandard2.0\Redzen.dll</HintPath>
46+
<Reference Include="Redzen, Version=9.0.0.0, Culture=neutral, PublicKeyToken=182843a4be0a74f7, processorArchitecture=MSIL">
47+
<HintPath>..\packages\Redzen.9.0.0-alpha1\lib\netstandard2.0\Redzen.dll</HintPath>
4848
</Reference>
4949
<Reference Include="System" />
5050
<Reference Include="System.Buffers, Version=4.0.3.0, Culture=neutral, PublicKeyToken=cc7b13ffcd2ddd51, processorArchitecture=MSIL">

src/SharpNeatDomainsExtra/packages.config

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
<packages>
33
<package id="Box2DX" version="2.0.2.0" targetFramework="net461" />
44
<package id="log4net" version="2.0.8" targetFramework="net461" />
5-
<package id="Redzen" version="8.0.0" targetFramework="net472" />
5+
<package id="Redzen" version="9.0.0-alpha1" targetFramework="net48" />
66
<package id="System.Buffers" version="4.5.0" targetFramework="net471" />
77
<package id="System.Memory" version="4.5.2" targetFramework="net471" />
88
<package id="System.Numerics.Vectors" version="4.5.0" targetFramework="net471" />

src/SharpNeatGUI/SharpNeatGUI.csproj

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@
4646
<HintPath>..\packages\log4net.2.0.8\lib\net45-full\log4net.dll</HintPath>
4747
<Private>True</Private>
4848
</Reference>
49-
<Reference Include="Redzen, Version=8.0.0.0, Culture=neutral, PublicKeyToken=182843a4be0a74f7, processorArchitecture=MSIL">
50-
<HintPath>..\packages\Redzen.8.0.0\lib\netstandard2.0\Redzen.dll</HintPath>
49+
<Reference Include="Redzen, Version=9.0.0.0, Culture=neutral, PublicKeyToken=182843a4be0a74f7, processorArchitecture=MSIL">
50+
<HintPath>..\packages\Redzen.9.0.0-alpha1\lib\netstandard2.0\Redzen.dll</HintPath>
5151
</Reference>
5252
<Reference Include="System" />
5353
<Reference Include="System.Buffers, Version=4.0.3.0, Culture=neutral, PublicKeyToken=cc7b13ffcd2ddd51, processorArchitecture=MSIL">

src/SharpNeatGUI/packages.config

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<?xml version="1.0" encoding="utf-8"?>
22
<packages>
33
<package id="log4net" version="2.0.8" targetFramework="net461" />
4-
<package id="Redzen" version="8.0.0" targetFramework="net472" />
4+
<package id="Redzen" version="9.0.0-alpha1" targetFramework="net48" />
55
<package id="System.Buffers" version="4.5.0" targetFramework="net471" />
66
<package id="System.Memory" version="4.5.2" targetFramework="net471" />
77
<package id="System.Numerics.Vectors" version="4.5.0" targetFramework="net471" />

src/SharpNeatLib/SharpNeatLib.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
<ItemGroup>
2121
<PackageReference Include="log4net" Version="2.0.8" />
22-
<PackageReference Include="Redzen" Version="8.0.0" />
22+
<PackageReference Include="Redzen" Version="9.0.0-alpha1" />
2323
</ItemGroup>
2424

2525
<ItemGroup>

0 commit comments

Comments
 (0)