Skip to content

Commit b959a4e

Browse files
ericastorcopybara-github
authored andcommitted
[XLS] Support StateRead predicates in state legalization
PiperOrigin-RevId: 709185318
1 parent 4dc4230 commit b959a4e

File tree

3 files changed

+227
-4
lines changed

3 files changed

+227
-4
lines changed

xls/scheduling/BUILD

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,9 @@ cc_library(
351351
"//xls/ir:op",
352352
"//xls/ir:state_element",
353353
"//xls/solvers:z3_ir_translator",
354+
"@com_google_absl//absl/algorithm:container",
354355
"@com_google_absl//absl/container:btree",
356+
"@com_google_absl//absl/container:flat_hash_set",
355357
"@com_google_absl//absl/log",
356358
"@com_google_absl//absl/log:check",
357359
"@com_google_absl//absl/status:statusor",

xls/scheduling/proc_state_legalization_pass.cc

Lines changed: 129 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@
1919
#include <variant>
2020
#include <vector>
2121

22+
#include "absl/algorithm/container.h"
2223
#include "absl/container/btree_set.h"
24+
#include "absl/container/flat_hash_set.h"
2325
#include "absl/log/check.h"
2426
#include "absl/log/log.h"
2527
#include "absl/status/statusor.h"
@@ -63,6 +65,91 @@ absl::StatusOr<bool> ModernizeNextValues(Proc* proc) {
6365
return proc->GetStateElementCount() > 0;
6466
}
6567

68+
class StateReadPredicateRemover : public Proc::StateElementTransformer {
69+
public:
70+
~StateReadPredicateRemover() override = default;
71+
72+
absl::StatusOr<std::optional<Node*>> TransformReadPredicate(
73+
Proc* proc, StateRead* old_state_read) override {
74+
return std::nullopt;
75+
}
76+
};
77+
78+
// Ensure that `state_read` is either unpredicated or has a predicate that is
79+
// true whenever any of its corresponding `next_value`s are active.
80+
absl::StatusOr<bool> LegalizeStateReadPredicate(
81+
Proc* proc, StateElement* state_element,
82+
const SchedulingPassOptions& options) {
83+
StateRead* state_read = proc->GetStateRead(state_element);
84+
const absl::btree_set<Next*, Node::NodeIdLessThan>& next_values =
85+
proc->next_values(state_read);
86+
if (!state_read->predicate().has_value() || next_values.empty()) {
87+
// No predicate; nothing to do.
88+
return false;
89+
}
90+
91+
if (absl::c_any_of(next_values, [](const Next* next) {
92+
return !next->predicate().has_value();
93+
})) {
94+
StateReadPredicateRemover predicate_remover;
95+
XLS_RETURN_IF_ERROR(proc->TransformStateElement(
96+
state_read,
97+
state_read->state_element()->initial_value(),
98+
predicate_remover)
99+
.status());
100+
return true;
101+
}
102+
103+
std::vector<Node*> predicates;
104+
absl::flat_hash_set<Node*> predicates_set;
105+
predicates.reserve(1 + next_values.size());
106+
predicates_set.reserve(next_values.size());
107+
for (Next* next : next_values) {
108+
CHECK(next->predicate().has_value());
109+
predicates.push_back(*next->predicate());
110+
predicates_set.insert(*next->predicate());
111+
}
112+
113+
Node* state_read_predicate = *state_read->predicate();
114+
if (state_read_predicate->op() == Op::kOr &&
115+
predicates_set ==
116+
absl::flat_hash_set<Node*>(predicates.begin(), predicates.end())) {
117+
// The predicate is already trivially correct; nothing to do.
118+
return false;
119+
}
120+
if (predicates_set.size() == 1 &&
121+
predicates.front() == state_read_predicate) {
122+
// The predicate is already trivially correct; nothing to do.
123+
return false;
124+
}
125+
126+
predicates.insert(predicates.begin(), state_read_predicate);
127+
XLS_ASSIGN_OR_RETURN(
128+
Node * new_predicate,
129+
NaryOrIfNeeded(proc, predicates, /*name=*/"", state_read->loc()));
130+
XLS_RETURN_IF_ERROR(state_read->ReplaceOperandNumber(
131+
*state_read->predicate_operand_number(), new_predicate));
132+
return true;
133+
}
134+
135+
absl::StatusOr<bool> LegalizeStateReadPredicates(
136+
Proc* proc, const SchedulingPassOptions& options) {
137+
bool changed = false;
138+
139+
for (StateElement* state_element : proc->StateElements()) {
140+
XLS_ASSIGN_OR_RETURN(
141+
bool state_read_changed,
142+
LegalizeStateReadPredicate(proc, state_element, options));
143+
if (state_read_changed) {
144+
VLOG(4) << "Generalized read predicate for state element: "
145+
<< state_element->name();
146+
changed = true;
147+
}
148+
}
149+
150+
return changed;
151+
}
152+
66153
absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,
67154
StateElement* state_element,
68155
const SchedulingPassOptions& options) {
@@ -83,7 +170,7 @@ absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,
83170
XLS_RETURN_IF_ERROR(proc->MakeNodeWithName<Next>(
84171
state_read->loc(), /*state_read=*/state_read,
85172
/*value=*/state_read,
86-
/*predicate=*/std::nullopt,
173+
/*predicate=*/state_read->predicate(),
87174
absl::StrCat(state_element->name(), "_default"))
88175
.status());
89176
return true;
@@ -101,6 +188,22 @@ absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,
101188
continue;
102189
}
103190

191+
if (state_read->predicate().has_value() && predicate->OpIn({Op::kAnd}) &&
192+
predicate->operands().size() == 2) {
193+
// Check to see if this is just an `and` with the state read predicate. If
194+
// so, take the other operand & see if it's a not/nor of the other
195+
// conditions.
196+
if (predicate->operand(0) == *state_read->predicate()) {
197+
predicate = predicate->operand(1);
198+
} else if (predicate->operand(1) == *state_read->predicate()) {
199+
predicate = predicate->operand(0);
200+
} else {
201+
// It's not, so we can't trivially recognize it as being of the right
202+
// form.
203+
continue;
204+
}
205+
}
206+
104207
if (!predicate->OpIn({Op::kNot, Op::kNor})) {
105208
continue;
106209
}
@@ -144,13 +247,21 @@ absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,
144247
// Explicitly mark the param as unchanged when no other `next_value` node is
145248
// active.
146249
XLS_ASSIGN_OR_RETURN(
147-
Node * all_predicates_false,
250+
Node * default_predicate,
148251
NaryNorIfNeeded(proc, std::vector(predicates.begin(), predicates.end()),
149252
/*name=*/"", state_read->loc()));
253+
if (state_read->predicate().has_value()) {
254+
XLS_ASSIGN_OR_RETURN(
255+
default_predicate,
256+
proc->MakeNode<NaryOp>(
257+
state_read->loc(),
258+
absl::MakeConstSpan({*state_read->predicate(), default_predicate}),
259+
Op::kAnd));
260+
}
150261
XLS_RETURN_IF_ERROR(proc->MakeNodeWithName<Next>(
151262
state_read->loc(), /*state_read=*/state_read,
152263
/*value=*/state_read,
153-
/*predicate=*/all_predicates_false,
264+
/*predicate=*/default_predicate,
154265
absl::StrCat(state_element->name(), "_default"))
155266
.status());
156267
return true;
@@ -191,7 +302,21 @@ absl::StatusOr<bool> ProcStateLegalizationPass::RunOnFunctionBaseInternal(
191302
return ModernizeNextValues(proc);
192303
}
193304

194-
return AddDefaultNextValues(proc, options);
305+
bool changed = false;
306+
307+
XLS_ASSIGN_OR_RETURN(bool read_predicates_changed,
308+
LegalizeStateReadPredicates(proc, options));
309+
if (read_predicates_changed) {
310+
changed = true;
311+
}
312+
313+
XLS_ASSIGN_OR_RETURN(bool default_nexts_added,
314+
AddDefaultNextValues(proc, options));
315+
if (default_nexts_added) {
316+
changed = true;
317+
}
318+
319+
return changed;
195320
}
196321

197322
} // namespace xls

xls/scheduling/proc_state_legalization_pass_test.cc

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#include "xls/scheduling/proc_state_legalization_pass.h"
1616

17+
#include <optional>
18+
1719
#include "gmock/gmock.h"
1820
#include "gtest/gtest.h"
1921
#include "absl/status/status_matchers.h"
@@ -269,5 +271,99 @@ TEST_F(ProcStateLegalizationPassTest,
269271
m::Nor(positive_predicate.node(), negative_predicate.node()))));
270272
}
271273

274+
TEST_F(ProcStateLegalizationPassTest, ProcWithPredicatedStateRead) {
275+
auto p = CreatePackage();
276+
ProcBuilder pb("p", p.get());
277+
BValue x = pb.StateElement("x", Value(UBits(0, 32)));
278+
BValue x_even =
279+
pb.Eq(pb.UMod(x, pb.Literal(UBits(2, 32))), pb.Literal(UBits(0, 32)));
280+
BValue x_multiple_of_3 =
281+
pb.Eq(pb.UMod(x, pb.Literal(UBits(3, 32))), pb.Literal(UBits(0, 32)));
282+
BValue y = pb.StateElement("y", Value(UBits(0, 32)),
283+
/*read_predicate=*/x_even);
284+
pb.Next(x, pb.Add(x, pb.Literal(UBits(1, 32))));
285+
pb.Next(y, pb.Add(y, pb.Literal(UBits(1, 32))), x_multiple_of_3);
286+
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build());
287+
288+
ASSERT_THAT(Run(proc), IsOkAndHolds(true));
289+
290+
EXPECT_EQ(proc->GetStateRead(*proc->GetStateElement("x"))->predicate(),
291+
std::nullopt);
292+
EXPECT_THAT(
293+
proc->GetStateRead(*proc->GetStateElement("y"))->predicate(),
294+
Optional(m::Or(
295+
m::Eq(m::UMod(m::StateRead("x"), m::Literal(2)), m::Literal(0)),
296+
m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0)))));
297+
EXPECT_THAT(
298+
proc->next_values(proc->GetStateRead(*proc->GetStateElement("y"))),
299+
UnorderedElementsAre(
300+
m::Next(
301+
m::StateRead("y"), m::Add(m::StateRead("y"), m::Literal(1)),
302+
m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0))),
303+
m::Next(m::StateRead("y"), m::StateRead("y"),
304+
m::And(m::Or(m::Eq(m::UMod(m::StateRead("x"), m::Literal(2)),
305+
m::Literal(0)),
306+
m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)),
307+
m::Literal(0))),
308+
m::Not(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)),
309+
m::Literal(0)))))));
310+
}
311+
312+
TEST_F(ProcStateLegalizationPassTest,
313+
ProcWithCorrectlyPredicatedStateReadAndNoDefaultNextNeeded) {
314+
auto p = CreatePackage();
315+
ProcBuilder pb("p", p.get());
316+
BValue x = pb.StateElement("x", Value(UBits(0, 32)));
317+
BValue x_multiple_of_3 =
318+
pb.Eq(pb.UMod(x, pb.Literal(UBits(3, 32))), pb.Literal(UBits(0, 32)));
319+
BValue x_not_multiple_of_3 = pb.Not(x_multiple_of_3);
320+
BValue disjunction = pb.Or(x_multiple_of_3, x_not_multiple_of_3);
321+
BValue y = pb.StateElement("y", Value(UBits(0, 32)),
322+
/*read_predicate=*/disjunction);
323+
pb.Next(x, pb.Add(x, pb.Literal(UBits(1, 32))));
324+
pb.Next(y, pb.Add(y, pb.Literal(UBits(1, 32))), x_multiple_of_3);
325+
pb.Next(y, y, x_not_multiple_of_3);
326+
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build());
327+
328+
ASSERT_THAT(Run(proc), IsOkAndHolds(false));
329+
}
330+
331+
TEST_F(ProcStateLegalizationPassTest,
332+
ProcWithPredicatedStateReadAndNoDefaultNextNeeded) {
333+
auto p = CreatePackage();
334+
ProcBuilder pb("p", p.get());
335+
BValue x = pb.StateElement("x", Value(UBits(0, 32)));
336+
BValue x_even =
337+
pb.Eq(pb.UMod(x, pb.Literal(UBits(2, 32))), pb.Literal(UBits(0, 32)));
338+
BValue y = pb.StateElement("y", Value(UBits(0, 32)),
339+
/*read_predicate=*/x_even);
340+
BValue x_multiple_of_3 =
341+
pb.Eq(pb.UMod(x, pb.Literal(UBits(3, 32))), pb.Literal(UBits(0, 32)));
342+
BValue x_not_multiple_of_3 = pb.Not(x_multiple_of_3);
343+
pb.Next(x, pb.Add(x, pb.Literal(UBits(1, 32))));
344+
pb.Next(y, pb.Add(y, pb.Literal(UBits(1, 32))), x_multiple_of_3);
345+
pb.Next(y, y, x_not_multiple_of_3);
346+
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build());
347+
348+
ASSERT_THAT(Run(proc), IsOkAndHolds(true));
349+
350+
EXPECT_THAT(
351+
proc->GetStateRead(*proc->GetStateElement("y"))->predicate(),
352+
Optional(
353+
m::Or(m::Eq(m::UMod(m::StateRead("x"), m::Literal(2)), m::Literal(0)),
354+
m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0)),
355+
m::Not(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)),
356+
m::Literal(0))))));
357+
EXPECT_THAT(
358+
proc->next_values(proc->GetStateRead(*proc->GetStateElement("y"))),
359+
UnorderedElementsAre(
360+
m::Next(
361+
m::StateRead("y"), m::Add(m::StateRead("y"), m::Literal(1)),
362+
m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0))),
363+
m::Next(m::StateRead("y"), m::StateRead("y"),
364+
m::Not(m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)),
365+
m::Literal(0))))));
366+
}
367+
272368
} // namespace
273369
} // namespace xls

0 commit comments

Comments
 (0)