Skip to content
This repository was archived by the owner on Dec 1, 2021. It is now read-only.

Commit

Permalink
replace network::run() argument types to void* (#1223)
Browse files Browse the repository at this point in the history
  • Loading branch information
lm-kajihara authored Oct 5, 2020
1 parent 8804c87 commit 002f540
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 5 deletions.
2 changes: 1 addition & 1 deletion blueoil/converter/templates/include/network.tpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class SYM_PUBLIC Network
void get_input_shape(int32_t *shape);
void get_output_shape(int32_t *shape);

bool run(float *network_input, float *network_output);
bool run(void *network_input, void *network_output);

private:
// declarations
Expand Down
10 changes: 7 additions & 3 deletions blueoil/converter/templates/src/network.tpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ void Network::get_output_shape(int32_t *shape)
std::copy(output_shape, output_shape + output_rank, shape);
}

bool Network::run(float *network_input, float *network_output)
bool Network::run(void *network_input, void *network_output)
{
struct convolution_parameters Conv2D_struct;
struct binary_convolution_parameters binConv2D_struct;
Expand All @@ -293,7 +293,8 @@ bool Network::run(float *network_input, float *network_output)
{{- len -}},
{%- endfor %}
};
TensorView<{{ graph_input.dtype.cpptype() }}, MemoryLayout::{{ graph_input.dimension }}> {{ graph_input.name }}_output(network_input, {{ graph_input.name }}_shape);
TensorView<{{ graph_input.dtype.cpptype() }}, MemoryLayout::{{ graph_input.dimension }}> {{ graph_input.name }}_output(
reinterpret_cast<{{ graph_input.dtype.cpptype() }}*>(network_input), {{ graph_input.name }}_shape);
{{ '\n' -}}

{% for node in graph.non_variables -%}
Expand Down Expand Up @@ -334,7 +335,10 @@ bool Network::run(float *network_input, float *network_output)

// TODO: support multiple output
{% for out_k in graph_output.output_ops.keys() -%}
std::copy({{ graph_output.name }}_{{ out_k }}.data(), {{ graph_output.name }}_{{ out_k }}.data() + {{ graph_output.view.size_in_words_as_cpp }}, network_output);
std::copy(
{{ graph_output.name }}_{{ out_k }}.data(),
{{ graph_output.name }}_{{ out_k }}.data() + {{ graph_output.view.size_in_words_as_cpp }},
reinterpret_cast<decltype({{ graph_output.name }}_{{ out_k }}.data())>(network_output));
{% endfor -%}

return true;
Expand Down
2 changes: 1 addition & 1 deletion blueoil/converter/templates/src/network_c_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ extern "C" __attribute__ ((visibility ("default"))) void network_get_output_shap
nn->get_output_shape(shape);
}

extern "C" __attribute__ ((visibility ("default"))) void network_run(Network *nn, float *input, float *output)
extern "C" __attribute__ ((visibility ("default"))) void network_run(Network *nn, void *input, void *output)
{
nn->run(input, output);
}

0 comments on commit 002f540

Please sign in to comment.