19
19
#include < variant>
20
20
#include < vector>
21
21
22
+ #include " absl/algorithm/container.h"
22
23
#include " absl/container/btree_set.h"
24
+ #include " absl/container/flat_hash_set.h"
23
25
#include " absl/log/check.h"
24
26
#include " absl/log/log.h"
25
27
#include " absl/status/statusor.h"
@@ -63,6 +65,91 @@ absl::StatusOr<bool> ModernizeNextValues(Proc* proc) {
63
65
return proc->GetStateElementCount () > 0 ;
64
66
}
65
67
68
+ class StateReadPredicateRemover : public Proc ::StateElementTransformer {
69
+ public:
70
+ ~StateReadPredicateRemover () override = default ;
71
+
72
+ absl::StatusOr<std::optional<Node*>> TransformReadPredicate (
73
+ Proc* proc, StateRead* old_state_read) override {
74
+ return std::nullopt;
75
+ }
76
+ };
77
+
78
+ // Ensure that `state_read` is either unpredicated or has a predicate that is
79
+ // true whenever any of its corresponding `next_value`s are active.
80
+ absl::StatusOr<bool > LegalizeStateReadPredicate (
81
+ Proc* proc, StateElement* state_element,
82
+ const SchedulingPassOptions& options) {
83
+ StateRead* state_read = proc->GetStateRead (state_element);
84
+ const absl::btree_set<Next*, Node::NodeIdLessThan>& next_values =
85
+ proc->next_values (state_read);
86
+ if (!state_read->predicate ().has_value () || next_values.empty ()) {
87
+ // No predicate; nothing to do.
88
+ return false ;
89
+ }
90
+
91
+ if (absl::c_any_of (next_values, [](const Next* next) {
92
+ return !next->predicate ().has_value ();
93
+ })) {
94
+ StateReadPredicateRemover predicate_remover;
95
+ XLS_RETURN_IF_ERROR (proc->TransformStateElement (
96
+ state_read,
97
+ state_read->state_element ()->initial_value (),
98
+ predicate_remover)
99
+ .status ());
100
+ return true ;
101
+ }
102
+
103
+ std::vector<Node*> predicates;
104
+ absl::flat_hash_set<Node*> predicates_set;
105
+ predicates.reserve (1 + next_values.size ());
106
+ predicates_set.reserve (next_values.size ());
107
+ for (Next* next : next_values) {
108
+ CHECK (next->predicate ().has_value ());
109
+ predicates.push_back (*next->predicate ());
110
+ predicates_set.insert (*next->predicate ());
111
+ }
112
+
113
+ Node* state_read_predicate = *state_read->predicate ();
114
+ if (state_read_predicate->op () == Op::kOr &&
115
+ predicates_set ==
116
+ absl::flat_hash_set<Node*>(predicates.begin (), predicates.end ())) {
117
+ // The predicate is already trivially correct; nothing to do.
118
+ return false ;
119
+ }
120
+ if (predicates_set.size () == 1 &&
121
+ predicates.front () == state_read_predicate) {
122
+ // The predicate is already trivially correct; nothing to do.
123
+ return false ;
124
+ }
125
+
126
+ predicates.insert (predicates.begin (), state_read_predicate);
127
+ XLS_ASSIGN_OR_RETURN (
128
+ Node * new_predicate,
129
+ NaryOrIfNeeded (proc, predicates, /* name=*/ " " , state_read->loc ()));
130
+ XLS_RETURN_IF_ERROR (state_read->ReplaceOperandNumber (
131
+ *state_read->predicate_operand_number (), new_predicate));
132
+ return true ;
133
+ }
134
+
135
+ absl::StatusOr<bool > LegalizeStateReadPredicates (
136
+ Proc* proc, const SchedulingPassOptions& options) {
137
+ bool changed = false ;
138
+
139
+ for (StateElement* state_element : proc->StateElements ()) {
140
+ XLS_ASSIGN_OR_RETURN (
141
+ bool state_read_changed,
142
+ LegalizeStateReadPredicate (proc, state_element, options));
143
+ if (state_read_changed) {
144
+ VLOG (4 ) << " Generalized read predicate for state element: "
145
+ << state_element->name ();
146
+ changed = true ;
147
+ }
148
+ }
149
+
150
+ return changed;
151
+ }
152
+
66
153
absl::StatusOr<bool > AddDefaultNextValue (Proc* proc,
67
154
StateElement* state_element,
68
155
const SchedulingPassOptions& options) {
@@ -83,7 +170,7 @@ absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,
83
170
XLS_RETURN_IF_ERROR (proc->MakeNodeWithName <Next>(
84
171
state_read->loc (), /* state_read=*/ state_read,
85
172
/* value=*/ state_read,
86
- /* predicate=*/ std::nullopt ,
173
+ /* predicate=*/ state_read-> predicate () ,
87
174
absl::StrCat (state_element->name (), " _default" ))
88
175
.status ());
89
176
return true ;
@@ -101,6 +188,22 @@ absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,
101
188
continue ;
102
189
}
103
190
191
+ if (state_read->predicate ().has_value () && predicate->OpIn ({Op::kAnd }) &&
192
+ predicate->operands ().size () == 2 ) {
193
+ // Check to see if this is just an `and` with the state read predicate. If
194
+ // so, take the other operand & see if it's a not/nor of the other
195
+ // conditions.
196
+ if (predicate->operand (0 ) == *state_read->predicate ()) {
197
+ predicate = predicate->operand (1 );
198
+ } else if (predicate->operand (1 ) == *state_read->predicate ()) {
199
+ predicate = predicate->operand (0 );
200
+ } else {
201
+ // It's not, so we can't trivially recognize it as being of the right
202
+ // form.
203
+ continue ;
204
+ }
205
+ }
206
+
104
207
if (!predicate->OpIn ({Op::kNot , Op::kNor })) {
105
208
continue ;
106
209
}
@@ -144,13 +247,21 @@ absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,
144
247
// Explicitly mark the param as unchanged when no other `next_value` node is
145
248
// active.
146
249
XLS_ASSIGN_OR_RETURN (
147
- Node * all_predicates_false ,
250
+ Node * default_predicate ,
148
251
NaryNorIfNeeded (proc, std::vector (predicates.begin (), predicates.end ()),
149
252
/* name=*/ " " , state_read->loc ()));
253
+ if (state_read->predicate ().has_value ()) {
254
+ XLS_ASSIGN_OR_RETURN (
255
+ default_predicate,
256
+ proc->MakeNode <NaryOp>(
257
+ state_read->loc (),
258
+ absl::MakeConstSpan ({*state_read->predicate (), default_predicate}),
259
+ Op::kAnd ));
260
+ }
150
261
XLS_RETURN_IF_ERROR (proc->MakeNodeWithName <Next>(
151
262
state_read->loc (), /* state_read=*/ state_read,
152
263
/* value=*/ state_read,
153
- /* predicate=*/ all_predicates_false ,
264
+ /* predicate=*/ default_predicate ,
154
265
absl::StrCat (state_element->name (), " _default" ))
155
266
.status ());
156
267
return true ;
@@ -191,7 +302,21 @@ absl::StatusOr<bool> ProcStateLegalizationPass::RunOnFunctionBaseInternal(
191
302
return ModernizeNextValues (proc);
192
303
}
193
304
194
- return AddDefaultNextValues (proc, options);
305
+ bool changed = false ;
306
+
307
+ XLS_ASSIGN_OR_RETURN (bool read_predicates_changed,
308
+ LegalizeStateReadPredicates (proc, options));
309
+ if (read_predicates_changed) {
310
+ changed = true ;
311
+ }
312
+
313
+ XLS_ASSIGN_OR_RETURN (bool default_nexts_added,
314
+ AddDefaultNextValues (proc, options));
315
+ if (default_nexts_added) {
316
+ changed = true ;
317
+ }
318
+
319
+ return changed;
195
320
}
196
321
197
322
} // namespace xls
0 commit comments