
Commit 6a2d7e1

Merge pull request #1263 from ASolomatin/master
fix: Support for training a multi-input model using a dataset.
2 parents 7fb73cd + 93dda17 commit 6a2d7e1

3 files changed (+107, -2 lines)

Diff for: src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs (+13, -1)

@@ -112,7 +112,19 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
                 Steps = data_handler.Inferredsteps
             });
 
-            return evaluate(data_handler, callbacks, is_val, test_function);
+            Func<DataHandler, OwnedIterator, Dictionary<string, float>> testFunction;
+
+            if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
+                data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
+            {
+                testFunction = test_step_multi_inputs_function;
+            }
+            else
+            {
+                testFunction = test_function;
+            }
+
+            return evaluate(data_handler, callbacks, is_val, testFunction);
         }
 
         /// <summary>

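Model.Evaluate.cs above and Model.Fit.cs below apply the same selection rule before dispatching to a step function. Restated as a hedged sketch, with UseMultiInputStep a hypothetical helper name rather than anything in the commit:

// Hypothetical helper (not in the commit) restating the dispatch rule that
// both evaluate(IDatasetV2 ...) and fit(IDatasetV2 ...) now apply.
static bool UseMultiInputStep(IDatasetV2 ds)
{
    // structure.Length is the number of tensors in one dataset element:
    //   (x, y)      -> 2 -> single-input step function
    //   (x1, x2, y) -> 3 -> multi-input step function
    // FirstInputTensorCount > 1 also catches label-free elements such as the
    // (x1, x2) prediction dataset in the new test, whose structure length is 2.
    return ds.structure.Length > 2 || ds.FirstInputTensorCount > 1;
}

When the rule matches, evaluate passes test_step_multi_inputs_function and fit passes train_step_multi_inputs_function; otherwise the existing single-input step functions are used, so prior behaviour is unchanged.
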
Diff for: src/TensorFlowNET.Keras/Engine/Model.Fit.cs (+12, -1)

@@ -179,9 +179,20 @@ public ICallback fit(IDatasetV2 dataset,
                 StepsPerExecution = _steps_per_execution
             });
 
+            Func<DataHandler, OwnedIterator, Dictionary<string, float>> trainStepFunction;
+
+            if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
+                data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
+            {
+                trainStepFunction = train_step_multi_inputs_function;
+            }
+            else
+            {
+                trainStepFunction = train_step_function;
+            }
 
             return FitInternal(data_handler, epochs, validation_step, verbose, callbacks, validation_data: validation_data,
-                train_step_func: train_step_function);
+                train_step_func: trainStepFunction);
         }
 
         History FitInternal(DataHandler data_handler, int epochs, int validation_step, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,

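The new LeNetModelDataset test (next diff) exercises the change end to end. Condensed into a sketch rather than the exact test code, and assuming model, x1, and labels are prepared as in that test, the dataset-based multi-input path looks like this:

// Sketch condensed from the LeNetModelDataset test below; `labels` stands in
// for dataset.Train.Labels, and `model` is the two-input functional model.
var multiInputDataset = tf.data.Dataset.zip(
    tf.data.Dataset.from_tensor_slices(x1),       // first model input
    tf.data.Dataset.from_tensor_slices(x1),       // second model input
    tf.data.Dataset.from_tensor_slices(labels))   // labels zipped last
    .batch(8);

// Declare how many leading tensors of each element are model inputs;
// whatever follows is treated as labels.
multiInputDataset.FirstInputTensorCount = 2;

model.fit(multiInputDataset, epochs: 3);
(model as Engine.Model).evaluate(multiInputDataset);
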
Diff for: test/TensorFlowNET.Keras.UnitTest/MultiInputModelTest.cs (+82)

@@ -2,6 +2,7 @@
 using System;
 using Tensorflow.Keras.Optimizers;
 using Tensorflow.NumPy;
+using static Tensorflow.Binding;
 using static Tensorflow.KerasApi;
 
 namespace Tensorflow.Keras.UnitTest
@@ -54,10 +55,91 @@ public void LeNetModel()
             var x = new NDArray[] { x1, x2 };
             model.fit(x, dataset.Train.Labels, batch_size: 8, epochs: 3);
 
+            x1 = x1["0:8"];
+            x2 = x1;
+
+            x = new NDArray[] { x1, x2 };
+            var y = dataset.Train.Labels["0:8"];
+            (model as Engine.Model).evaluate(x, y);
+
             x1 = np.ones((1, 28, 28, 1), TF_DataType.TF_FLOAT);
             x2 = np.zeros((1, 28, 28, 1), TF_DataType.TF_FLOAT);
             var pred = model.predict((x1, x2));
             Console.WriteLine(pred);
         }
+
+        [TestMethod]
+        public void LeNetModelDataset()
+        {
+            var inputs = keras.Input((28, 28, 1));
+            var conv1 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs);
+            var pool1 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv1);
+            var conv2 = keras.layers.Conv2D(32, (3, 3), activation: "relu", padding: "same").Apply(pool1);
+            var pool2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2);
+            var flat1 = keras.layers.Flatten().Apply(pool2);
+
+            var inputs_2 = keras.Input((28, 28, 1));
+            var conv1_2 = keras.layers.Conv2D(16, (3, 3), activation: "relu", padding: "same").Apply(inputs_2);
+            var pool1_2 = keras.layers.MaxPooling2D((4, 4), 4).Apply(conv1_2);
+            var conv2_2 = keras.layers.Conv2D(32, (1, 1), activation: "relu", padding: "same").Apply(pool1_2);
+            var pool2_2 = keras.layers.MaxPooling2D((2, 2), 2).Apply(conv2_2);
+            var flat1_2 = keras.layers.Flatten().Apply(pool2_2);
+
+            var concat = keras.layers.Concatenate().Apply((flat1, flat1_2));
+            var dense1 = keras.layers.Dense(512, activation: "relu").Apply(concat);
+            var dense2 = keras.layers.Dense(128, activation: "relu").Apply(dense1);
+            var dense3 = keras.layers.Dense(10, activation: "relu").Apply(dense2);
+            var output = keras.layers.Softmax(-1).Apply(dense3);
+
+            var model = keras.Model((inputs, inputs_2), output);
+            model.summary();
+
+            var data_loader = new MnistModelLoader();
+
+            var dataset = data_loader.LoadAsync(new ModelLoadSetting
+            {
+                TrainDir = "mnist",
+                OneHot = false,
+                ValidationSize = 59900,
+            }).Result;
+
+            var loss = keras.losses.SparseCategoricalCrossentropy();
+            var optimizer = new Adam(0.001f);
+            model.compile(optimizer, loss, new string[] { "accuracy" });
+
+            NDArray x1 = np.reshape(dataset.Train.Data, (dataset.Train.Data.shape[0], 28, 28, 1));
+
+            var multiInputDataset = tf.data.Dataset.zip(
+                tf.data.Dataset.from_tensor_slices(x1),
+                tf.data.Dataset.from_tensor_slices(x1),
+                tf.data.Dataset.from_tensor_slices(dataset.Train.Labels)
+                ).batch(8);
+            multiInputDataset.FirstInputTensorCount = 2;
+
+            model.fit(multiInputDataset, epochs: 3);
+
+            x1 = x1["0:8"];
+
+            multiInputDataset = tf.data.Dataset.zip(
+                tf.data.Dataset.from_tensor_slices(x1),
+                tf.data.Dataset.from_tensor_slices(x1),
+                tf.data.Dataset.from_tensor_slices(dataset.Train.Labels["0:8"])
+                ).batch(8);
+            multiInputDataset.FirstInputTensorCount = 2;
+
+            (model as Engine.Model).evaluate(multiInputDataset);
+
+            x1 = np.ones((1, 28, 28, 1), TF_DataType.TF_FLOAT);
+            var x2 = np.zeros((1, 28, 28, 1), TF_DataType.TF_FLOAT);
+
+            multiInputDataset = tf.data.Dataset.zip(
+                tf.data.Dataset.from_tensor_slices(x1),
+                tf.data.Dataset.from_tensor_slices(x2)
+                ).batch(8);
+            multiInputDataset.FirstInputTensorCount = 2;
+
+            var pred = model.predict(multiInputDataset);
+            Console.WriteLine(pred);
+        }
     }
 }

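One detail from the test worth calling out: for prediction there are no labels, so only the two input tensors are zipped and FirstInputTensorCount is still set to 2; with nothing left over after the declared inputs, every tensor in each element is fed to the model. The test also calls evaluate through a cast to Engine.Model, which matches the first diff: the IDatasetV2 overload of evaluate is defined in Engine/Model.Evaluate.cs.
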