From 16602fc779cdd054815fb212841ca006fdddec7f Mon Sep 17 00:00:00 2001 From: Eric Astor Date: Wed, 18 Dec 2024 15:43:01 -0800 Subject: [PATCH] [XLS] Make sure the IR minimizer respects state-read arguments of next_value nodes Before this, the minimizer could attempt to replace these arguments with a literal and crash. PiperOrigin-RevId: 707690916 --- xls/dev_tools/ir_minimizer_main.cc | 12 ++++++------ xls/dev_tools/ir_minimizer_main_test.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/xls/dev_tools/ir_minimizer_main.cc b/xls/dev_tools/ir_minimizer_main.cc index 47d8296845..52bafbeda4 100644 --- a/xls/dev_tools/ir_minimizer_main.cc +++ b/xls/dev_tools/ir_minimizer_main.cc @@ -754,23 +754,23 @@ absl::StatusOr SimplifyReturnValue( return SimplificationResult::kDidNotChange; } -// Replace all uses of 'target' with a literal 'v' except for the 'param' -// argument of next nodes (which need to stay as params). +// Replace all uses of 'target' with a literal 'v' except for the 'state_read' +// argument of next nodes (which need to stay as StateReads). template absl::StatusOr SafeReplaceUsesWithNew(Node* target, Args... v) { auto is_state_read_of_next = [&](Node* n) { return n->Is() && n->As()->state_read() == target; }; - if (!target->Is() || + if (!target->Is() || !absl::c_any_of(target->users(), is_state_read_of_next)) { return target->ReplaceUsesWithNew(std::forward(v)...); } - std::vector param_users; - absl::c_copy_if(target->users(), std::back_inserter(param_users), + std::vector state_read_users; + absl::c_copy_if(target->users(), std::back_inserter(state_read_users), is_state_read_of_next); XLS_ASSIGN_OR_RETURN( auto result, target->ReplaceUsesWithNew(std::forward(v)...)); - for (Node* n : param_users) { + for (Node* n : state_read_users) { XLS_RETURN_IF_ERROR( n->As()->ReplaceOperandNumber(Next::kStateReadOperand, target)); } diff --git a/xls/dev_tools/ir_minimizer_main_test.py b/xls/dev_tools/ir_minimizer_main_test.py index ed82ab3af5..955d5fa92d 100644 --- a/xls/dev_tools/ir_minimizer_main_test.py +++ b/xls/dev_tools/ir_minimizer_main_test.py @@ -173,14 +173,14 @@ tuple_index.21: token = tuple_index(receive.20, index=0, id=21) send.25: token = send(tok__5, add.24, channel=multi_proc__bytes_result, id=25) tuple.26: () = tuple(id=26, pos=[(0,72,15)]) - next (tuple.26) + next_state: () = next_value(param=__state, value=tuple.26, id=33) } proc __multi_proc__proc_ten__proc_double_0_next(__state: (), init={()}) { tok: token = after_all(id=30) receive.31: (token, bits[32]) = receive(tok, channel=multi_proc__send_double_pipe, id=31) v: bits[32] = tuple_index(receive.31, index=1, id=34, pos=[(0,29,18)]) - tok__1: token = tuple_index(receive.31, index=0, id=33, pos=[(0,29,13)]) + tok__1: token = tuple_index(receive.31, index=0, id=62, pos=[(0,29,13)]) invoke.35: bits[32] = invoke(v, to_apply=__multi_proc__double_it, id=35, pos=[(0,30,41)]) __token: token = literal(value=token, id=27) literal.29: bits[1] = literal(value=1, id=29)