From 2807f693f73c87ac90391bb42202bf4c726c4fef Mon Sep 17 00:00:00 2001 From: Eunju Yang Date: Wed, 2 Oct 2024 14:24:32 +0900 Subject: [PATCH] [ App ] Multi-Input Example Update - This commit is related to issue #2660 - When using multi-inputs, users must feed the data in reverse order due to a known bug that needs fixing. In the current version, the input must be provided in reverse order, which was not shown in the previous example where random data with the same dimensions were used. - To provide a more accurate example to NNTrainer users, I have temporarily updated this example. - Once the issue is handled, further updates will be necessary. Signed-off-by: Eunju Yang --- Applications/Multi_input/Readme.md | 27 +++++++++++++++++++ Applications/Multi_input/jni/main.cpp | 18 ++++++++----- Applications/Multi_input/jni/multi_loader.cpp | 7 ++--- 3 files changed, 42 insertions(+), 10 deletions(-) create mode 100644 Applications/Multi_input/Readme.md diff --git a/Applications/Multi_input/Readme.md b/Applications/Multi_input/Readme.md new file mode 100644 index 0000000000..9476579ae9 --- /dev/null +++ b/Applications/Multi_input/Readme.md @@ -0,0 +1,27 @@ +# Multi_Input example + +- This example demonstrates how to build and train a model with multiple `input` layers. +- The NNTrainer supports a network that takes multiple tensors as inputs. +- Users can create multiple `input` layers for the network with their own names and build the network accordingly. +- This code includes an example of training with the network shown below.
+ +``` + +-----------+ + | output | + +-----------+ + | + +---------------------------------------------------+ + | flatten | + +---------------------------------------------------+ + | + +---------------------------------------------------+ + | concat0 | + +---------------------------------------------------+ + | | | + +-----------+ +-----------+ +-----------+ + | input 2 | | input 1 | | input 0 | + +-----------+ +-----------+ +-----------+ + +``` + +- **[Note]** Users must feed the input tensors in reverse order of the model's input layers (i.e., the data for `input2` first, then `input1`, then `input0`) because the model is structured in a reversed manner internally. This is a known issue, and we plan to address it soon. \ No newline at end of file diff --git a/Applications/Multi_input/jni/main.cpp b/Applications/Multi_input/jni/main.cpp index 6edb781829..23e556f890 100644 --- a/Applications/Multi_input/jni/main.cpp +++ b/Applications/Multi_input/jni/main.cpp @@ -63,14 +63,18 @@ ModelHandle createMultiInputModel() { layers.push_back(createLayer( "input", {withKey("name", "input0"), withKey("input_shape", "1:2:2")})); layers.push_back(createLayer( - "input", {withKey("name", "input1"), withKey("input_shape", "1:2:2")})); + "input", {withKey("name", "input1"), withKey("input_shape", "1:4:2")})); layers.push_back(createLayer( - "input", {withKey("name", "input2"), withKey("input_shape", "1:2:2")})); + "input", {withKey("name", "input2"), withKey("input_shape", "1:8:2")})); layers.push_back( - createLayer("concat", {withKey("name", "concat0"), withKey("axis", "3"), + createLayer("concat", {withKey("name", "concat0"), withKey("axis", "2"), withKey("input_layers", "input0, input1, input2")})); + layers.push_back( + createLayer("flatten", {withKey("name", "flatten0"), + withKey("input_layers", "concat0")})); + layers.push_back(createLayer( "fully_connected", {withKey("unit", 5), withKey("activation", "softmax")})); @@ -123,16 +127,16 @@ std::array createFakeMultiDataGenerator(unsigned int batch_size, unsigned int simulated_data_size) { UserDataType 
train_data(new nntrainer::util::MultiDataLoader( - {{batch_size, 1, 2, 2}, {batch_size, 1, 2, 2}, {batch_size, 1, 2, 2}}, + {{batch_size, 1, 2, 2}, {batch_size, 1, 4, 2}, {batch_size, 1, 8, 2}}, {{batch_size, 1, 1, 5}}, simulated_data_size)); return {std::move(train_data)}; } int main(int argc, char *argv[]) { - unsigned int total_data_size = 16; - unsigned int batch_size = 2; - unsigned int epoch = 2; + unsigned int total_data_size = 32; + unsigned int batch_size = 4; + unsigned int epoch = 10; std::array user_datas; diff --git a/Applications/Multi_input/jni/multi_loader.cpp b/Applications/Multi_input/jni/multi_loader.cpp index d4ce36c019..3c112758f8 100644 --- a/Applications/Multi_input/jni/multi_loader.cpp +++ b/Applications/Multi_input/jni/multi_loader.cpp @@ -78,10 +78,11 @@ void MultiDataLoader::next(float **input, float **label, bool *last) { }; float **cur_input_tensor = input; + const auto num_input = input_shapes.size() - 1; for (unsigned int i = 0; i < input_shapes.size(); ++i) { - fill_input(*cur_input_tensor, input_shapes.at(i).getFeatureLen(), - indicies[count]); - cur_input_tensor++; + fill_input(*cur_input_tensor, + input_shapes.at(num_input - i).getFeatureLen(), indicies[count]); + ++cur_input_tensor; } float **cur_label_tensor = label;