|
| 1 | +using BrightWire.ExecutionGraph; |
| 2 | +using BrightWire.Models; |
| 3 | +using System; |
| 4 | +using System.Collections.Generic; |
| 5 | +using System.IO; |
| 6 | +using System.Linq; |
| 7 | +using System.Text; |
| 8 | +using System.Threading.Tasks; |
| 9 | + |
| 10 | +namespace BrightWire.SampleCode |
| 11 | +{ |
| 12 | + public partial class Program |
| 13 | + { |
| 14 | + const int CLASSIFICATION_COUNT = 6; |
| 15 | + |
| 16 | + static IDataTable _LoadEmotionData(string dataFilePath) |
| 17 | + { |
| 18 | + // read the data as CSV, skipping the header |
| 19 | + using (var reader = new StreamReader(dataFilePath)) { |
| 20 | + while (!reader.EndOfStream) { |
| 21 | + var line = reader.ReadLine(); |
| 22 | + if (line == "@data") |
| 23 | + break; |
| 24 | + } |
| 25 | + return reader.ReadToEnd().ParseCSV(',', false); |
| 26 | + } |
| 27 | + } |
| 28 | + |
| 29 | + /// <summary> |
| 30 | + /// Trains a feed forward neural net on the emotion dataset |
| 31 | + /// http://lpis.csd.auth.gr/publications/tsoumakas-ismir08.pdf |
| 32 | + /// The data files can be downloaded from https://downloads.sourceforge.net/project/mulan/datasets/emotions.rar |
| 33 | + /// </summary> |
| 34 | + /// <param name="dataFilePath"></param> |
| 35 | + public static void MultiLabelSingleClassifier(string dataFilePath) |
| 36 | + { |
| 37 | + var emotionData = _LoadEmotionData(dataFilePath); |
| 38 | + var attributeColumns = Enumerable.Range(0, emotionData.ColumnCount - CLASSIFICATION_COUNT).ToList(); |
| 39 | + var classificationColumns = Enumerable.Range(emotionData.ColumnCount - CLASSIFICATION_COUNT, CLASSIFICATION_COUNT).ToList(); |
| 40 | + |
| 41 | + // create a new data table with a vector input column and a vector output column |
| 42 | + var dataTableBuilder = BrightWireProvider.CreateDataTableBuilder(); |
| 43 | + dataTableBuilder.AddColumn(ColumnType.Vector, "Attributes"); |
| 44 | + dataTableBuilder.AddColumn(ColumnType.Vector, "Target", isTarget: true); |
| 45 | + emotionData.ForEach(row => { |
| 46 | + var input = FloatVector.Create(row.GetFields<float>(attributeColumns).ToArray()); |
| 47 | + var target = FloatVector.Create(row.GetFields<float>(classificationColumns).ToArray()); |
| 48 | + dataTableBuilder.Add(input, target); |
| 49 | + return true; |
| 50 | + }); |
| 51 | + var data = dataTableBuilder.Build().Split(0); |
| 52 | + |
| 53 | + // train a neural network |
| 54 | + using (var lap = BrightWireProvider.CreateLinearAlgebra(false)) { |
| 55 | + var graph = new GraphFactory(lap); |
| 56 | + |
| 57 | + // binary classification rounds each output to 0 or 1 and compares each output against the binary classification targets |
| 58 | + var errorMetric = graph.ErrorMetric.BinaryClassification; |
| 59 | + |
| 60 | + // configure the network properties |
| 61 | + graph.CurrentPropertySet |
| 62 | + .Use(graph.GradientDescent.Adam) |
| 63 | + .Use(graph.WeightInitialisation.Xavier) |
| 64 | + ; |
| 65 | + |
| 66 | + // create a training engine |
| 67 | + const float TRAINING_RATE = 0.3f; |
| 68 | + var trainingData = graph.CreateDataSource(data.Training); |
| 69 | + var testData = trainingData.CloneWith(data.Test); |
| 70 | + var engine = graph.CreateTrainingEngine(trainingData, TRAINING_RATE, 128); |
| 71 | + |
| 72 | + // build the network |
| 73 | + const int HIDDEN_LAYER_SIZE = 64, TRAINING_ITERATIONS = 2000; |
| 74 | + var network = graph.Connect(engine) |
| 75 | + .AddFeedForward(HIDDEN_LAYER_SIZE) |
| 76 | + .Add(graph.SigmoidActivation()) |
| 77 | + .AddDropOut(dropOutPercentage: 0.5f) |
| 78 | + .AddFeedForward(engine.DataSource.OutputSize) |
| 79 | + .Add(graph.SigmoidActivation()) |
| 80 | + .AddBackpropagation(errorMetric) |
| 81 | + ; |
| 82 | + |
| 83 | + // train the network |
| 84 | + engine.Train(TRAINING_ITERATIONS, testData, errorMetric, null, 50); |
| 85 | + } |
| 86 | + } |
| 87 | + |
| 88 | + /// <summary> |
| 89 | + /// Trains multiple classifiers on the emotion data set |
| 90 | + /// http://lpis.csd.auth.gr/publications/tsoumakas-ismir08.pdf |
| 91 | + /// The data files can be downloaded from https://downloads.sourceforge.net/project/mulan/datasets/emotions.rar |
| 92 | + /// </summary> |
| 93 | + /// <param name="dataFilePath"></param> |
| 94 | + public static void MultiLabelMultiClassifiers(string dataFilePath) |
| 95 | + { |
| 96 | + var emotionData = _LoadEmotionData(dataFilePath); |
| 97 | + var attributeCount = emotionData.ColumnCount - CLASSIFICATION_COUNT; |
| 98 | + var attributeColumns = Enumerable.Range(0, attributeCount).ToList(); |
| 99 | + var classificationColumns = Enumerable.Range(emotionData.ColumnCount - CLASSIFICATION_COUNT, CLASSIFICATION_COUNT).ToList(); |
| 100 | + var classificationLabel = new[] { |
| 101 | + "amazed-suprised", |
| 102 | + "happy-pleased", |
| 103 | + "relaxing-calm", |
| 104 | + "quiet-still", |
| 105 | + "sad-lonely", |
| 106 | + "angry-aggresive" |
| 107 | + }; |
| 108 | + |
| 109 | + // create six separate datasets to train, each with a separate classification column |
| 110 | + var dataSets = Enumerable.Range(attributeCount, CLASSIFICATION_COUNT).Select(targetIndex => { |
| 111 | + var dataTableBuider = BrightWireProvider.CreateDataTableBuilder(); |
| 112 | + for (var i = 0; i < attributeCount; i++) |
| 113 | + dataTableBuider.AddColumn(ColumnType.Float); |
| 114 | + dataTableBuider.AddColumn(ColumnType.Float, "", true); |
| 115 | + |
| 116 | + return emotionData.Project(row => row.GetFields<float>(attributeColumns) |
| 117 | + .Concat(new[] { row.GetField<float>(targetIndex) }) |
| 118 | + .Cast<object>() |
| 119 | + .ToList() |
| 120 | + ); |
| 121 | + }).Select(ds => ds.Split(0)).ToList(); |
| 122 | + |
| 123 | + // train classifiers on each training set |
| 124 | + using (var lap = BrightWireProvider.CreateLinearAlgebra(false)) { |
| 125 | + var graph = new GraphFactory(lap); |
| 126 | + |
| 127 | + // binary classification rounds each output to 0 or 1 and compares each output against the binary classification targets |
| 128 | + var errorMetric = graph.ErrorMetric.BinaryClassification; |
| 129 | + |
| 130 | + // configure the network properties |
| 131 | + graph.CurrentPropertySet |
| 132 | + .Use(graph.GradientDescent.Adam) |
| 133 | + .Use(graph.WeightInitialisation.Xavier) |
| 134 | + ; |
| 135 | + |
| 136 | + for (var i = 0; i < CLASSIFICATION_COUNT; i++) { |
| 137 | + var trainingSet = dataSets[i].Training; |
| 138 | + var testSet = dataSets[i].Test; |
| 139 | + Console.WriteLine("Training on {0}", classificationLabel[i]); |
| 140 | + |
| 141 | + // train and evaluate a naive bayes classifier |
| 142 | + var naiveBayes = trainingSet.TrainNaiveBayes().CreateClassifier(); |
| 143 | + Console.WriteLine("\tNaive bayes accuracy: {0:P}", testSet |
| 144 | + .Classify(naiveBayes) |
| 145 | + .Average(d => d.Row.GetField<string>(attributeCount) == d.Classification ? 1.0 : 0.0) |
| 146 | + ); |
| 147 | + |
| 148 | + // train a logistic regression classifier |
| 149 | + var logisticRegression = trainingSet |
| 150 | + .TrainLogisticRegression(lap, 2500, 0.25f, 0.01f) |
| 151 | + .CreatePredictor(lap) |
| 152 | + .ConvertToRowClassifier(attributeColumns) |
| 153 | + ; |
| 154 | + Console.WriteLine("\tLogistic regression accuracy: {0:P}", testSet |
| 155 | + .Classify(logisticRegression) |
| 156 | + .Average(d => d.Row.GetField<string>(attributeCount) == d.Classification ? 1.0 : 0.0) |
| 157 | + ); |
| 158 | + |
| 159 | + // train and evaluate k nearest neighbours |
| 160 | + var knn = trainingSet.TrainKNearestNeighbours().CreateClassifier(lap, 10); |
| 161 | + Console.WriteLine("\tK nearest neighbours accuracy: {0:P}", testSet |
| 162 | + .Classify(knn) |
| 163 | + .Average(d => d.Row.GetField<string>(attributeCount) == d.Classification ? 1.0 : 0.0) |
| 164 | + ); |
| 165 | + |
| 166 | + // create a training engine |
| 167 | + const float TRAINING_RATE = 0.1f; |
| 168 | + var trainingData = graph.CreateDataSource(trainingSet); |
| 169 | + var testData = trainingData.CloneWith(testSet); |
| 170 | + var engine = graph.CreateTrainingEngine(trainingData, TRAINING_RATE, 64); |
| 171 | + |
| 172 | + // build the network |
| 173 | + const int HIDDEN_LAYER_SIZE = 64, TRAINING_ITERATIONS = 2000; |
| 174 | + var network = graph.Connect(engine) |
| 175 | + .AddFeedForward(HIDDEN_LAYER_SIZE) |
| 176 | + .Add(graph.SigmoidActivation()) |
| 177 | + .AddDropOut(dropOutPercentage: 0.5f) |
| 178 | + .AddFeedForward(engine.DataSource.OutputSize) |
| 179 | + .Add(graph.SigmoidActivation()) |
| 180 | + .AddBackpropagation(errorMetric) |
| 181 | + ; |
| 182 | + |
| 183 | + // train the network |
| 184 | + engine.Train(TRAINING_ITERATIONS, testData, errorMetric, null, 200); |
| 185 | + } |
| 186 | + } |
| 187 | + } |
| 188 | + } |
| 189 | +} |
0 commit comments