Skip to content

Commit 73ac016

Browse files
author
Jack Dermody
committed
added multi label examples
1 parent 161c8b5 commit 73ac016

File tree

4 files changed

+194
-3
lines changed

4 files changed

+194
-3
lines changed

Diff for: BrightWire.Source/ExecutionGraph/WireBuilder.cs

+2-3
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,8 @@ public WireBuilder IncrementSizeBy(int delta)
9191

9292
void _SetNode(INode node)
9393
{
94-
if (_node != null)
95-
_node.Output.Add(new WireToNode(node));
96-
_node = node;
94+
_node?.Output.Add(new WireToNode(node));
95+
_node = node;
9796
}
9897

9998
/// <summary>

Diff for: SampleCode/MultiLabel.cs

+189
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
using BrightWire.ExecutionGraph;
2+
using BrightWire.Models;
3+
using System;
4+
using System.Collections.Generic;
5+
using System.IO;
6+
using System.Linq;
7+
using System.Text;
8+
using System.Threading.Tasks;
9+
10+
namespace BrightWire.SampleCode
11+
{
12+
public partial class Program
13+
{
14+
const int CLASSIFICATION_COUNT = 6;
15+
16+
/// <summary>
/// Loads the emotion data from an ARFF file: every line up to and including the
/// "@data" marker is discarded and the remaining text is parsed as comma separated values
/// </summary>
/// <param name="dataFilePath">Path to the ARFF data file</param>
/// <returns>The data table that ParseCSV builds from the text following the marker</returns>
static IDataTable _LoadEmotionData(string dataFilePath)
{
    using (var reader = new StreamReader(dataFilePath)) {
        // skip the header section; ARFF keywords are case insensitive, so "@DATA" (or a
        // marker with stray surrounding whitespace) must also end the header
        while (!reader.EndOfStream) {
            var line = reader.ReadLine();
            if (string.Equals(line.Trim(), "@data", StringComparison.OrdinalIgnoreCase))
                break;
        }

        // everything after the marker is the comma separated data section
        return reader.ReadToEnd().ParseCSV(',', false);
    }
}
28+
29+
/// <summary>
/// Trains one feed forward neural network on the emotion data set, with the final
/// CLASSIFICATION_COUNT columns of each row combined into a single vector target
/// http://lpis.csd.auth.gr/publications/tsoumakas-ismir08.pdf
/// The data files can be downloaded from https://downloads.sourceforge.net/project/mulan/datasets/emotions.rar
/// </summary>
/// <param name="dataFilePath">Path to the emotions data file</param>
public static void MultiLabelSingleClassifier(string dataFilePath)
{
    var table = _LoadEmotionData(dataFilePath);

    // the label columns are the last CLASSIFICATION_COUNT columns; everything before them is an attribute
    var firstLabelColumn = table.ColumnCount - CLASSIFICATION_COUNT;
    var featureIndices = Enumerable.Range(0, firstLabelColumn).ToList();
    var labelIndices = Enumerable.Range(firstLabelColumn, CLASSIFICATION_COUNT).ToList();

    // rebuild the table so that each row is a pair of vectors: the attributes and the labels
    var builder = BrightWireProvider.CreateDataTableBuilder();
    builder.AddColumn(ColumnType.Vector, "Attributes");
    builder.AddColumn(ColumnType.Vector, "Target", isTarget: true);
    table.ForEach(row => {
        var features = FloatVector.Create(row.GetFields<float>(featureIndices).ToArray());
        var labels = FloatVector.Create(row.GetFields<float>(labelIndices).ToArray());
        builder.Add(features, labels);
        return true;
    });
    var split = builder.Build().Split(0);

    using (var linearAlgebra = BrightWireProvider.CreateLinearAlgebra(false)) {
        var factory = new GraphFactory(linearAlgebra);

        // binary classification rounds every output to 0 or 1 before comparing it with the matching target value
        var metric = factory.ErrorMetric.BinaryClassification;

        // network wide settings, applied before any layers are created
        factory.CurrentPropertySet
            .Use(factory.GradientDescent.Adam)
            .Use(factory.WeightInitialisation.Xavier)
        ;

        // the training engine and the data sources it reads from
        const float learningRate = 0.3f;
        var trainingSource = factory.CreateDataSource(split.Training);
        var testSource = trainingSource.CloneWith(split.Test);
        var trainer = factory.CreateTrainingEngine(trainingSource, learningRate, 128);

        // hidden layer -> sigmoid -> drop out -> output layer -> sigmoid
        const int hiddenLayerSize = 64, iterationCount = 2000;
        factory.Connect(trainer)
            .AddFeedForward(hiddenLayerSize)
            .Add(factory.SigmoidActivation())
            .AddDropOut(dropOutPercentage: 0.5f)
            .AddFeedForward(trainer.DataSource.OutputSize)
            .Add(factory.SigmoidActivation())
            .AddBackpropagation(metric)
        ;

        // run the training
        trainer.Train(iterationCount, testSource, metric, null, 50);
    }
}
87+
88+
/// <summary>
/// Trains multiple classifiers on the emotion data set: for each of the CLASSIFICATION_COUNT
/// label columns a naive bayes, logistic regression, k nearest neighbours and neural network
/// classifier is trained, with the accuracy of the first three written to the console
/// http://lpis.csd.auth.gr/publications/tsoumakas-ismir08.pdf
/// The data files can be downloaded from https://downloads.sourceforge.net/project/mulan/datasets/emotions.rar
/// </summary>
/// <param name="dataFilePath">Path to the emotions data file</param>
public static void MultiLabelMultiClassifiers(string dataFilePath)
{
    var emotionData = _LoadEmotionData(dataFilePath);
    var attributeCount = emotionData.ColumnCount - CLASSIFICATION_COUNT;
    var attributeColumns = Enumerable.Range(0, attributeCount).ToList();

    // display names, in the same order as the classification columns
    var classificationLabel = new[] {
        "amazed-suprised",
        "happy-pleased",
        "relaxing-calm",
        "quiet-still",
        "sad-lonely",
        "angry-aggresive"
    };

    // create one dataset per classification column: each row keeps every attribute and appends
    // a single classification value, which therefore always ends up at index attributeCount
    var dataSets = Enumerable.Range(attributeCount, CLASSIFICATION_COUNT)
        .Select(targetIndex => emotionData.Project(row => row.GetFields<float>(attributeColumns)
            .Concat(new[] { row.GetField<float>(targetIndex) })
            .Cast<object>()
            .ToList()
        ))
        .Select(ds => ds.Split(0))
        .ToList();

    // train classifiers on each training set
    using (var lap = BrightWireProvider.CreateLinearAlgebra(false)) {
        var graph = new GraphFactory(lap);

        // binary classification rounds each output to 0 or 1 and compares each output against the binary classification targets
        var errorMetric = graph.ErrorMetric.BinaryClassification;

        // configure the network properties
        graph.CurrentPropertySet
            .Use(graph.GradientDescent.Adam)
            .Use(graph.WeightInitialisation.Xavier)
        ;

        // neural network settings shared by every classification column
        const float TRAINING_RATE = 0.1f;
        const int HIDDEN_LAYER_SIZE = 64, TRAINING_ITERATIONS = 2000;

        for (var i = 0; i < CLASSIFICATION_COUNT; i++) {
            var trainingSet = dataSets[i].Training;
            var testSet = dataSets[i].Test;
            Console.WriteLine("Training on {0}", classificationLabel[i]);

            // train and evaluate a naive bayes classifier
            var naiveBayes = trainingSet.TrainNaiveBayes().CreateClassifier();
            Console.WriteLine("\tNaive bayes accuracy: {0:P}", testSet
                .Classify(naiveBayes)
                .Average(d => d.Row.GetField<string>(attributeCount) == d.Classification ? 1.0 : 0.0)
            );

            // train and evaluate a logistic regression classifier
            var logisticRegression = trainingSet
                .TrainLogisticRegression(lap, 2500, 0.25f, 0.01f)
                .CreatePredictor(lap)
                .ConvertToRowClassifier(attributeColumns)
            ;
            Console.WriteLine("\tLogistic regression accuracy: {0:P}", testSet
                .Classify(logisticRegression)
                .Average(d => d.Row.GetField<string>(attributeCount) == d.Classification ? 1.0 : 0.0)
            );

            // train and evaluate k nearest neighbours
            var knn = trainingSet.TrainKNearestNeighbours().CreateClassifier(lap, 10);
            Console.WriteLine("\tK nearest neighbours accuracy: {0:P}", testSet
                .Classify(knn)
                .Average(d => d.Row.GetField<string>(attributeCount) == d.Classification ? 1.0 : 0.0)
            );

            // create a training engine
            var trainingData = graph.CreateDataSource(trainingSet);
            var testData = trainingData.CloneWith(testSet);
            var engine = graph.CreateTrainingEngine(trainingData, TRAINING_RATE, 64);

            // build the network
            graph.Connect(engine)
                .AddFeedForward(HIDDEN_LAYER_SIZE)
                .Add(graph.SigmoidActivation())
                .AddDropOut(dropOutPercentage: 0.5f)
                .AddFeedForward(engine.DataSource.OutputSize)
                .Add(graph.SigmoidActivation())
                .AddBackpropagation(errorMetric)
            ;

            // train the network
            engine.Train(TRAINING_ITERATIONS, testData, errorMetric, null, 200);
        }
    }
}
188+
}
189+
}

Diff for: SampleCode/Program.cs

+2
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ static void Main(string[] args)
9393
//SimpleLinearTest();
9494
//PredictBicyclesWithLinearModel(@"D:\data\bikesharing\hour.csv");
9595
//PredictBicyclesWithNeuralNetwork(@"D:\data\bikesharing\hour.csv");
96+
//MultiLabelSingleClassifier(@"d:\data\emotions\emotions.arff");
97+
//MultiLabelMultiClassifiers(@"d:\data\emotions\emotions.arff");
9698
}
9799
}
98100
}

Diff for: SampleCode/SampleCode.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
<Compile Include="IrisClassification.cs" />
6565
<Compile Include="IrisClustering.cs" />
6666
<Compile Include="MarkovChains.cs" />
67+
<Compile Include="MultiLabel.cs" />
6768
<Compile Include="Program.cs" />
6869
<Compile Include="Properties\AssemblyInfo.cs" />
6970
<Compile Include="MNIST.cs" />

0 commit comments

Comments
 (0)