Skip to content

Commit

Permalink
styles applied
Browse files Browse the repository at this point in the history
  • Loading branch information
mbencer committed Feb 21, 2025
1 parent 5f4135b commit 4361786
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 187 deletions.
40 changes: 21 additions & 19 deletions compiler/circle-resizer/include/CircleResizer.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,30 +23,32 @@
#include <memory>
#include <vector>

namespace luci {
namespace luci
{
class Module;
}

namespace circle_resizer
{
class CircleResizer {
public:
explicit CircleResizer(const std::string& model_path);
// to satisfy forward declaration + unique_ptr
~CircleResizer();

public:
void resize_model(const std::vector<Shape>& shapes);
void save_model(const std::string& output_path) const;

public:
std::vector<Shape> input_shapes() const;
std::vector<Shape> output_shapes() const;

private:
std::string _model_path;
std::unique_ptr<luci::Module> _module;
};
class CircleResizer
{
public:
explicit CircleResizer(const std::string &model_path);
// to satisfy forward declaration + unique_ptr
~CircleResizer();

public:
void resize_model(const std::vector<Shape> &shapes);
void save_model(const std::string &output_path) const;

public:
std::vector<Shape> input_shapes() const;
std::vector<Shape> output_shapes() const;

private:
std::string _model_path;
std::unique_ptr<luci::Module> _module;
};
} // namespace circle_resizer

#endif // __CIRCLE_RESIZER_H__
25 changes: 13 additions & 12 deletions compiler/circle-resizer/include/Shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,22 @@

namespace circle_resizer
{
class Dim {
public:
explicit Dim(int32_t dim);
class Dim
{
public:
explicit Dim(int32_t dim);

public:
bool is_dynamic();
int32_t value() const;
bool operator==(const Dim& rhs) const;
public:
bool is_dynamic();
int32_t value() const;
bool operator==(const Dim &rhs) const;

private:
// Note that in the future, we might need to support dimension with lower and upper bounds
int32_t _dim_value;
};
private:
// Note that in the future, we might need to support dimension with lower and upper bounds
int32_t _dim_value;
};

using Shape = std::vector<Dim>;
using Shape = std::vector<Dim>;
} // namespace circle_resizer

#endif // __CIRCLE_RESIZER_SHAPE_H__
7 changes: 4 additions & 3 deletions compiler/circle-resizer/include/ShapeParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@
#include <string>
#include <vector>

namespace circle_resizer {
std::vector<Shape> parse_shapes(std::string shapes_str);
} // circle_resizer
namespace circle_resizer
{
std::vector<Shape> parse_shapes(std::string shapes_str);
} // namespace circle_resizer

#endif // __CIRCLE_RESIZER_SHAPE_PARSER_H__
66 changes: 35 additions & 31 deletions compiler/circle-resizer/src/CircleResizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,42 +36,47 @@

using namespace circle_resizer;

namespace {
std::vector<uint8_t> read_model(const std::string& model_path) {
std::ifstream file_stream(model_path, std::ios::in | std::ios::binary | std::ifstream::ate);
if (!file_stream.is_open()) {
throw std::runtime_error("Failed to open file: " + model_path);
}
namespace
{
std::vector<uint8_t> read_model(const std::string &model_path)
{
std::ifstream file_stream(model_path, std::ios::in | std::ios::binary | std::ifstream::ate);
if (!file_stream.is_open())
{
throw std::runtime_error("Failed to open file: " + model_path);
}

std::streamsize size = file_stream.tellg();
file_stream.seekg(0, std::ios::beg);
std::streamsize size = file_stream.tellg();
file_stream.seekg(0, std::ios::beg);

std::vector<uint8_t> buffer(size);
if (!file_stream.read(reinterpret_cast<char*>(buffer.data()), size)) {
throw std::runtime_error("Failed to read file: " + model_path);
}
std::vector<uint8_t> buffer(size);
if (!file_stream.read(reinterpret_cast<char *>(buffer.data()), size))
{
throw std::runtime_error("Failed to read file: " + model_path);
}

return buffer;
return buffer;
}

void replace_tensor_shape(::flatbuffers::Vector<int32_t>* tensor_shape, const Shape& new_shape)
void replace_tensor_shape(::flatbuffers::Vector<int32_t> *tensor_shape, const Shape &new_shape)
{
const auto shape_size = tensor_shape->size();
if(shape_size != new_shape.size())
if (shape_size != new_shape.size())
{
throw std::runtime_error("Provided shape size: " + std::to_string(new_shape.size()) + " is different from expected: " + std::to_string(shape_size));
throw std::runtime_error("Provided shape size: " + std::to_string(new_shape.size()) +
" is different from expected: " + std::to_string(shape_size));
}
for (uint32_t dim_idx = 0; dim_idx < shape_size; ++dim_idx)
{
tensor_shape->Mutate(dim_idx, new_shape[dim_idx].value());
}
}

template<typename NodeType>
std::vector<Shape> extract_shapes(const std::vector<loco::Node *>& nodes)
template <typename NodeType>
std::vector<Shape> extract_shapes(const std::vector<loco::Node *> &nodes)
{
std::vector<Shape> shapes;
for(const auto& loco_node : nodes)
for (const auto &loco_node : nodes)
{
shapes.push_back(Shape{});
const auto circle_node = loco::must_cast<const NodeType *>(loco_node);
Expand All @@ -86,34 +91,33 @@ std::vector<Shape> extract_shapes(const std::vector<loco::Node *>& nodes)

} // namespace

CircleResizer::CircleResizer(const std::string& model_path)
: _model_path{model_path}
{
}
CircleResizer::CircleResizer(const std::string &model_path) : _model_path{model_path} {}

void CircleResizer::resize_model(const std::vector<Shape>& shapes)
void CircleResizer::resize_model(const std::vector<Shape> &shapes)
{
auto model_buffer = read_model(_model_path);
auto model = circle::GetMutableModel(model_buffer.data());
if (!model)
{
throw std::runtime_error("Incorrect model format");
throw std::runtime_error("Incorrect model format");
}
auto subgraphs = model->mutable_subgraphs();
if (!subgraphs || subgraphs->size() != 1)
{
throw std::runtime_error("Many subgraphs are not supported");
throw std::runtime_error("Many subgraphs are not supported");
}
auto subgraph = subgraphs->GetMutableObject(0);
const auto inputs_number = subgraph->inputs()->size();
if(!inputs_number == shapes.size())
if (!inputs_number == shapes.size())
{
throw std::runtime_error("Expected input shapes: " + std::to_string(inputs_number) + " while provided: " + std::to_string(shapes.size()));
throw std::runtime_error("Expected input shapes: " + std::to_string(inputs_number) +
" while provided: " + std::to_string(shapes.size()));
}
for(int in_idx = 0; in_idx < inputs_number; ++in_idx)
for (int in_idx = 0; in_idx < inputs_number; ++in_idx)
{
const auto in_tensor_idx = subgraph->inputs()->Get(in_idx);
auto input_shape = subgraph->mutable_tensors()->GetMutableObject(in_tensor_idx)->mutable_shape();
auto input_shape =
subgraph->mutable_tensors()->GetMutableObject(in_tensor_idx)->mutable_shape();
replace_tensor_shape(input_shape, shapes[in_idx]);
}

Expand All @@ -137,7 +141,7 @@ void CircleResizer::resize_model(const std::vector<Shape>& shapes)
phase_runner.run(phase);
}

void CircleResizer::save_model(const std::string& output_path) const
void CircleResizer::save_model(const std::string &output_path) const
{
luci::CircleExporter exporter;
luci::CircleFileExpContract contract(_module.get(), output_path);
Expand Down
18 changes: 8 additions & 10 deletions compiler/circle-resizer/src/Shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,16 @@

using namespace circle_resizer;

Dim::Dim(int32_t dim) : _dim_value{dim} {
if (dim < -1) {
throw std::runtime_error("Invalid value of dimension: " + dim);
}
Dim::Dim(int32_t dim) : _dim_value{dim}
{
if (dim < -1)
{
throw std::runtime_error("Invalid value of dimension: " + dim);
}
}

bool Dim::is_dynamic() { return _dim_value == -1; }

int32_t Dim::value() const {
return _dim_value;
}
int32_t Dim::value() const { return _dim_value; }

bool Dim::operator==(const Dim& rhs) const {
return value() == rhs.value();
}
bool Dim::operator==(const Dim &rhs) const { return value() == rhs.value(); }
91 changes: 49 additions & 42 deletions compiler/circle-resizer/src/ShapeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,55 +22,62 @@

using namespace circle_resizer;

namespace {
bool is_blank(const std::string& s)
{
return !s.empty() && std::find_if(s.begin(),
s.end(), [](unsigned char c) { return !std::isblank(c); }) == s.end();
}
namespace
{
bool is_blank(const std::string &s)
{
return !s.empty() && std::find_if(s.begin(), s.end(),
[](unsigned char c) { return !std::isblank(c); }) == s.end();
}

Shape parse_shape(std::string shape_str) {
Shape result_shape;
std::stringstream shape_stream(shape_str);
std::string token;
try
{
while (std::getline(shape_stream, token, ','))
{
result_shape.push_back(Dim{std::stoi(token)});
}
}
catch (...)
{
throw std::invalid_argument("Error during shape processing: " + shape_str);
}
if(result_shape.empty()) {
throw std::invalid_argument("No shapes found in input string: " + shape_str);
}
return result_shape;
Shape parse_shape(std::string shape_str)
{
Shape result_shape;
std::stringstream shape_stream(shape_str);
std::string token;
try
{
while (std::getline(shape_stream, token, ','))
{
result_shape.push_back(Dim{std::stoi(token)});
}
}
catch (...)
{
throw std::invalid_argument("Error during shape processing: " + shape_str);
}
if (result_shape.empty())
{
throw std::invalid_argument("No shapes found in input string: " + shape_str);
}
return result_shape;
}
} // namespace

std::vector<Shape> circle_resizer::parse_shapes(std::string shapes_str)
{
std::vector<Shape> result_shapes;
std::stringstream shapes_stream(shapes_str);
std::string token;
size_t begin_pos = 0, end_pos=0;
while ( (begin_pos = shapes_str.find("[")) != std::string::npos && (end_pos = shapes_str.find("]")) != std::string::npos ) {
token = shapes_str.substr(begin_pos + 1, end_pos);
result_shapes.push_back(parse_shape(token));
shapes_str.erase(0, end_pos + 1);
}
std::vector<Shape> result_shapes;
std::stringstream shapes_stream(shapes_str);
std::string token;
size_t begin_pos = 0, end_pos = 0;
while ((begin_pos = shapes_str.find("[")) != std::string::npos &&
(end_pos = shapes_str.find("]")) != std::string::npos)
{
token = shapes_str.substr(begin_pos + 1, end_pos);
result_shapes.push_back(parse_shape(token));
shapes_str.erase(0, end_pos + 1);
}

if(result_shapes.empty()) {
throw std::invalid_argument("No shapes found in input string: " + shapes_str);
}
if (result_shapes.empty())
{
throw std::invalid_argument("No shapes found in input string: " + shapes_str);
}

// the rest of input not processed by loop above cannot be processed properly
if(shapes_str.size() > 0 && !is_blank(shapes_str)) {
throw std::invalid_argument("The part of input shape: " + shapes_str + " cannot be processed");
}
// the rest of input not processed by loop above cannot be processed properly
if (shapes_str.size() > 0 && !is_blank(shapes_str))
{
throw std::invalid_argument("The part of input shape: " + shapes_str + " cannot be processed");
}

return result_shapes;
return result_shapes;
}
3 changes: 2 additions & 1 deletion compiler/circle-resizer/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

using namespace circle_resizer;

int main(int argc, char *argv[]) {
int main(int argc, char *argv[])
{
CircleResizer resizer(argv[1]);
resizer.resize_model({Shape{Dim{1}, Dim{3}}}); // experiment
resizer.save_model(argv[2]);
Expand Down
Loading

0 comments on commit 4361786

Please sign in to comment.