Commit 0b81901

Merge branch 'master' of https://github.com/colgreen/sharpneat
2 parents: 4d1c5b6 + c9e5f8b

2 files changed: +79, -39 lines

Diff for: ActivationFunctionRegistry.cs (new file)

@@ -0,0 +1,78 @@
/* ***************************************************************************
 * This file is part of SharpNEAT - Evolution of Neural Networks.
 *
 * Copyright 2004-2016 Colin Green ([email protected])
 *
 * Written by Scott DeBoer, 2019 ([email protected])
 *
 * SharpNEAT is free software; you can redistribute it and/or modify
 * it under the terms of The MIT License (MIT).
 *
 * You should have received a copy of the MIT License
 * along with SharpNEAT; if not, see https://opensource.org/licenses/MIT.
 */
using System;
using System.Collections.Generic;

namespace SharpNeat.Network
{
    public static class ActivationFunctionRegistry
    {
        // The default set of activation functions.
        private static Dictionary<string, IActivationFunction> __activationFunctionTable = new Dictionary<string, IActivationFunction>()
        {
            // Bipolar.
            { "BipolarSigmoid", BipolarSigmoid.__DefaultInstance },
            { "BipolarGaussian", BipolarGaussian.__DefaultInstance },
            { "Linear", Linear.__DefaultInstance },
            { "Sine", Sine.__DefaultInstance },

            // Unipolar.
            { "ArcSinH", ArcSinH.__DefaultInstance },
            { "ArcTan", ArcTan.__DefaultInstance },
            { "Gaussian", Gaussian.__DefaultInstance },
            { "LeakyReLU", LeakyReLU.__DefaultInstance },
            { "LeakyReLUShifted", LeakyReLUShifted.__DefaultInstance },
            { "LogisticFunction", LogisticFunction.__DefaultInstance },
            { "LogisticFunctionSteep", LogisticFunctionSteep.__DefaultInstance },
            { "MaxMinusOne", MaxMinusOne.__DefaultInstance },
            { "PolynomialApproximantSteep", PolynomialApproximantSteep.__DefaultInstance },
            { "QuadraticSigmoid", QuadraticSigmoid.__DefaultInstance },
            { "ReLU", ReLU.__DefaultInstance },
            { "ScaledELU", ScaledELU.__DefaultInstance },
            { "SoftSignSteep", SoftSignSteep.__DefaultInstance },
            { "SReLU", SReLU.__DefaultInstance },
            { "SReLUShifted", SReLUShifted.__DefaultInstance },
            { "TanH", TanH.__DefaultInstance },

            // Radial Basis.
            { "RbfGaussian", RbfGaussian.__DefaultInstance },
        };

        #region Public Static Methods

        /// <summary>
        /// Registers a custom activation function in addition to those in the default activation function table.
        /// Allows loading of neural nets from XML that use custom activation functions.
        /// </summary>
        public static void RegisterActivationFunction(IActivationFunction function)
        {
            if(!__activationFunctionTable.ContainsKey(function.FunctionId)) {
                __activationFunctionTable.Add(function.FunctionId, function);
            }
        }

        /// <summary>
        /// Gets an IActivationFunction with the given short name.
        /// </summary>
        public static IActivationFunction GetActivationFunction(string name)
        {
            if(!__activationFunctionTable.ContainsKey(name)) {
                throw new ArgumentException($"Unexpected activation function [{name}]");
            }
            return __activationFunctionTable[name];
        }

        #endregion
    }
}
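The commit ships no usage example, so here is a minimal, self-contained sketch of the registration semantics the class above implements. The IActivation interface and StepFunction type below are illustrative stand-ins, not part of the commit: SharpNEAT's real IActivationFunction has a larger surface than the FunctionId property, which is the only member the registry code actually reads.

using System;
using System.Collections.Generic;

// Stand-in for SharpNEAT's IActivationFunction, trimmed to the one member
// the registry reads (FunctionId) plus an evaluation method for the demo.
// The real interface defines more members than this.
public interface IActivation
{
    string FunctionId { get; }
    double Calculate(double x);
}

// Hypothetical custom function; not part of the commit.
public sealed class StepFunction : IActivation
{
    public string FunctionId => "StepFunction";
    public double Calculate(double x) => x < 0.0 ? 0.0 : 1.0;
}

public static class RegistryDemo
{
    private static readonly Dictionary<string, IActivation> _table =
        new Dictionary<string, IActivation>();

    // Mirrors RegisterActivationFunction: the first registration of an id
    // wins; a second registration under the same id is silently ignored.
    public static void Register(IActivation fn)
    {
        if(!_table.ContainsKey(fn.FunctionId)) {
            _table.Add(fn.FunctionId, fn);
        }
    }

    // Mirrors GetActivationFunction: unknown names throw, as in the diff.
    public static IActivation Get(string name)
    {
        if(!_table.ContainsKey(name)) {
            throw new ArgumentException($"Unexpected activation function [{name}]");
        }
        return _table[name];
    }

    public static void Main()
    {
        Register(new StepFunction());
        Register(new StepFunction()); // ignored: id already registered
        Console.WriteLine(Get("StepFunction").Calculate(0.5)); // prints 1
    }
}

Note that the silent-ignore behavior on duplicate ids means a caller cannot override a built-in entry such as "TanH"; replacing one would require removing it first or assigning through the dictionary indexer rather than Add.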

Diff for: src/SharpNeatLib/Network/NetworkXmlIO.cs

+1, -39

@@ -459,7 +459,7 @@ public static IActivationFunctionLibrary ReadActivationFunctionLibrary(XmlReader
                 string fnName = xrSubtree.GetAttribute(__AttrName);
 
                 // Lookup function name.
-                IActivationFunction activationFn = GetActivationFunction(fnName);
+                IActivationFunction activationFn = ActivationFunctionRegistry.GetActivationFunction(fnName);
 
                 // Add new function to our list of functions.
                 ActivationFunctionInfo fnInfo = new ActivationFunctionInfo(id, selectionProb, activationFn);
@@ -528,44 +528,6 @@ public static string GetNodeTypeString(NodeType nodeType)
             throw new ArgumentException($"Unexpected NodeType [{nodeType}]");
         }
 
-        /// <summary>
-        /// Gets an IActivationFunction from its short name.
-        /// </summary>
-        public static IActivationFunction GetActivationFunction(string name)
-        {
-            switch(name)
-            {
-                // Bipolar.
-                case "BipolarGaussian": return BipolarGaussian.__DefaultInstance;
-                case "BipolarSigmoid": return BipolarSigmoid.__DefaultInstance;
-                case "Linear": return Linear.__DefaultInstance;
-                case "Sine": return Sine.__DefaultInstance;
-
-                // Unipolar.
-                case "ArcSinH": return ArcSinH.__DefaultInstance;
-                case "ArcTan": return ArcTan.__DefaultInstance;
-                case "Gaussian": return Gaussian.__DefaultInstance;
-                case "LeakyReLU": return LeakyReLU.__DefaultInstance;
-                case "LeakyReLUShifted": return LeakyReLUShifted.__DefaultInstance;
-                case "LogisticFunction": return LogisticFunction.__DefaultInstance;
-                case "LogisticFunctionSteep": return LogisticFunctionSteep.__DefaultInstance;
-                case "MaxMinusOne": return MaxMinusOne.__DefaultInstance;
-                case "PolynomialApproximantSteep": return PolynomialApproximantSteep.__DefaultInstance;
-                case "QuadraticSigmoid": return QuadraticSigmoid.__DefaultInstance;
-                case "ReLU": return ReLU.__DefaultInstance;
-                case "ScaledELU": return ScaledELU.__DefaultInstance;
-                case "SoftSignSteep": return SoftSignSteep.__DefaultInstance;
-                case "SReLU": return SReLU.__DefaultInstance;
-                case "SReLUShifted": return SReLUShifted.__DefaultInstance;
-                case "TanH": return TanH.__DefaultInstance;
-
-                // Radial Basis.
-                case "RbfGaussian": return RbfGaussian.__DefaultInstance;
-
-            }
-            throw new ArgumentException($"Unexpected activation function [{name}]");
-        }
-
         #endregion
 
         #region Private Static Methods