Skip to content

Commit

Permalink
[XLS] Make sure the IR minimizer respects state-read arguments of nex…
Browse files Browse the repository at this point in the history
…t_value nodes

Before this, the minimizer could attempt to replace these arguments with a literal and crash.

PiperOrigin-RevId: 707690916
  • Loading branch information
ericastor authored and copybara-github committed Dec 18, 2024
1 parent 9016ba0 commit 16602fc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
12 changes: 6 additions & 6 deletions xls/dev_tools/ir_minimizer_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -754,23 +754,23 @@ absl::StatusOr<SimplificationResult> SimplifyReturnValue(
return SimplificationResult::kDidNotChange;
}

// Replace all uses of 'target' with a literal 'v' except for the 'param'
// argument of next nodes (which need to stay as params).
// Replace all uses of 'target' with a literal 'v' except for the 'state_read'
// argument of next nodes (which need to stay as StateReads).
template <typename NodeT, typename... Args>
absl::StatusOr<Node*> SafeReplaceUsesWithNew(Node* target, Args... v) {
auto is_state_read_of_next = [&](Node* n) {
return n->Is<Next>() && n->As<Next>()->state_read() == target;
};
if (!target->Is<Param>() ||
if (!target->Is<StateRead>() ||
!absl::c_any_of(target->users(), is_state_read_of_next)) {
return target->ReplaceUsesWithNew<NodeT>(std::forward<Args>(v)...);
}
std::vector<Node*> param_users;
absl::c_copy_if(target->users(), std::back_inserter(param_users),
std::vector<Node*> state_read_users;
absl::c_copy_if(target->users(), std::back_inserter(state_read_users),
is_state_read_of_next);
XLS_ASSIGN_OR_RETURN(
auto result, target->ReplaceUsesWithNew<NodeT>(std::forward<Args>(v)...));
for (Node* n : param_users) {
for (Node* n : state_read_users) {
XLS_RETURN_IF_ERROR(
n->As<Next>()->ReplaceOperandNumber(Next::kStateReadOperand, target));
}
Expand Down
4 changes: 2 additions & 2 deletions xls/dev_tools/ir_minimizer_main_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,14 +173,14 @@
tuple_index.21: token = tuple_index(receive.20, index=0, id=21)
send.25: token = send(tok__5, add.24, channel=multi_proc__bytes_result, id=25)
tuple.26: () = tuple(id=26, pos=[(0,72,15)])
next (tuple.26)
next_state: () = next_value(param=__state, value=tuple.26, id=33)
}
proc __multi_proc__proc_ten__proc_double_0_next(__state: (), init={()}) {
tok: token = after_all(id=30)
receive.31: (token, bits[32]) = receive(tok, channel=multi_proc__send_double_pipe, id=31)
v: bits[32] = tuple_index(receive.31, index=1, id=34, pos=[(0,29,18)])
tok__1: token = tuple_index(receive.31, index=0, id=33, pos=[(0,29,13)])
tok__1: token = tuple_index(receive.31, index=0, id=62, pos=[(0,29,13)])
invoke.35: bits[32] = invoke(v, to_apply=__multi_proc__double_it, id=35, pos=[(0,30,41)])
__token: token = literal(value=token, id=27)
literal.29: bits[1] = literal(value=1, id=29)
Expand Down

0 comments on commit 16602fc

Please sign in to comment.