Skip to content

Commit 27a8ad6

Browse files
committed
fix(frontend): 填充边信息时不能接受输入边不存在
Signed-off-by: YdrMaster <[email protected]>
1 parent 35dc6c8 commit 27a8ad6

File tree

4 files changed

+16
-22
lines changed

4 files changed

+16
-22
lines changed

src/00common/include/common/error_handler.h

+3-3
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,9 @@ namespace refactor {
3030
std::abort()
3131

3232
#ifndef DISABLE_ASSERT
33-
#define ASSERT(CONDITION, F, ...) \
34-
{ \
35-
if (!(CONDITION)) RUNTIME_ERROR(fmt::format("Assertion: " #F, ##__VA_ARGS__)); \
33+
#define ASSERT(CONDITION, F, ...) \
34+
{ \
35+
if (!(CONDITION)) RUNTIME_ERROR(fmt::format("Assertion: " F, ##__VA_ARGS__)); \
3636
}
3737
#else
3838
#define ASSERT(CONDITION, F)

src/06frontend/src/graph.cc

+9-15
Original file line numberDiff line numberDiff line change
@@ -98,27 +98,21 @@ namespace refactor::frontend {
9898
auto const startTime = high_resolution_clock::now();
9999
// 拓扑遍历
100100
for (auto [nodeIdx, inputs, outputs] : _internal.topology) {
101-
auto unknownEdge = false, inputChanged = false;
102-
for (auto i : inputs) {
103-
auto const &input = _internal.edges[i].tensor;
104-
if (!input) {// 有入边未知
105-
unknownEdge = true;
106-
break;
107-
}
108-
auto checked = edgeChanged[2 * i]; // NOTICE `std::vector<bool>::operator[]` 产生常引用!!!
109-
auto changed = edgeChanged[2 * i + 1];// NOTICE `std::vector<bool>::operator[]` 产生常引用!!!
101+
auto inputChanged = false;
102+
for (auto i : range0_(inputs.size())) {
103+
auto j = inputs[i];
104+
auto const &input = _internal.edges[j].tensor;
105+
ASSERT(input, "The {}th input of \"{}\" is nullptr", i, _internal.nodes[nodeIdx].name);
106+
auto checked = edgeChanged[2 * j]; // NOTICE `std::vector<bool>::operator[]` 产生常引用!!!
107+
auto changed = edgeChanged[2 * j + 1];// NOTICE `std::vector<bool>::operator[]` 产生常引用!!!
110108
if (!checked) {
111109
checked = true;
112-
if (changed = _edgeSnapshot[i] != *input) {
113-
_edgeSnapshot[i] = input->snapshot();
110+
if (changed = _edgeSnapshot[j] != *input) {
111+
_edgeSnapshot[j] = input->snapshot();
114112
}
115113
}
116114
inputChanged |= changed;
117115
}
118-
// 有入边未知,跳过节点
119-
if (unknownEdge) {
120-
continue;
121-
}
122116
if (!inputChanged && std::all_of(outputs.begin(), outputs.end(),
123117
[this](auto i) { return _internal.edges[i].tensor; })) {
124118
// 入边未发生变化,且出边已推导

src/07onnx/src/operators/split.cc

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ namespace refactor::onnx {
1212
numOutputs(numOutputs_) {}
1313

1414
auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
15-
auto axis = attributes.getOrInsert( "axis", {0}).int_();
16-
auto numOutputs = attributes.getOrInsert( "num_outputs", {0}).int_();
15+
auto axis = attributes.getOrInsert("axis", {0}).int_();
16+
auto numOutputs = attributes.getOrInsert("num_outputs", {0}).int_();
1717
return OpBox(std::make_unique<Op>(axis, numOutputs));
1818
}
1919
auto Op::typeId() -> size_t {
@@ -45,7 +45,7 @@ namespace refactor::onnx {
4545
ans[i] = Tensor::share(input.dataType, input.shape, dependencies);
4646
ans[i]->shape[axis_] = DimExpr(each);
4747
} else {
48-
ASSERT(i == numOutputs - 1, ERROR_MSG("Split error"));
48+
ASSERT(i == numOutputs - 1, "Split error");
4949
ans[i] = Tensor::share(input.dataType, input.shape, dependencies);
5050
ans[i]->shape[axis_] = DimExpr(total);
5151
}

src/07onnx/src/operators/tile.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace refactor::onnx {
2929
return Err(InferError(ERROR_MSG("repeats not support")));
3030
}
3131
EXPECT_VAL(repeats.shape[0], repeatsSize)
32-
ASSERT(repeatsSize == rank, ERROR_MSG("repeats size error"));
32+
ASSERT(repeatsSize == rank, "repeats size error");
3333

3434
auto repeats_ = repeats.data->get<int64_t>();
3535
Shape output(rank, DimExpr(1));

0 commit comments

Comments
 (0)