Skip to content

Commit 2766e34

Browse files
meheffernancopybara-github
authored andcommitted
Add support for new-style procs to dead function elimination pass.
For new-style procs, elaboration trivially determines proc liveness. PiperOrigin-RevId: 597682240
1 parent 27c3660 commit 2766e34

File tree

4 files changed

+175
-54
lines changed

4 files changed

+175
-54
lines changed

xls/passes/BUILD

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,7 @@ cc_library(
422422
"//xls/data_structures:union_find",
423423
"//xls/ir",
424424
"//xls/ir:channel",
425+
"//xls/ir:elaboration",
425426
"//xls/ir:node_util",
426427
"//xls/ir:op",
427428
],
@@ -2081,11 +2082,14 @@ cc_test(
20812082
"//xls/common:xls_gunit",
20822083
"//xls/common:xls_gunit_main",
20832084
"//xls/common/status:matchers",
2085+
"//xls/common/status:status_macros",
20842086
"//xls/ir",
20852087
"//xls/ir:bits",
2088+
"//xls/ir:channel",
20862089
"//xls/ir:function_builder",
20872090
"//xls/ir:ir_matcher",
20882091
"//xls/ir:ir_test_base",
2092+
"//xls/ir:value",
20892093
],
20902094
)
20912095

xls/passes/dfe_pass.cc

Lines changed: 89 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,20 @@
1818
#include <memory>
1919
#include <optional>
2020
#include <string>
21+
#include <string_view>
2122
#include <vector>
2223

2324
#include "absl/container/flat_hash_map.h"
2425
#include "absl/container/flat_hash_set.h"
25-
#include "absl/status/status.h"
2626
#include "absl/status/statusor.h"
27-
#include "absl/strings/str_format.h"
2827
#include "xls/common/logging/logging.h"
2928
#include "xls/common/status/status_macros.h"
3029
#include "xls/data_structures/union_find.h"
3130
#include "xls/ir/block.h"
3231
#include "xls/ir/channel.h"
32+
#include "xls/ir/elaboration.h"
3333
#include "xls/ir/function_base.h"
34+
#include "xls/ir/instantiation.h"
3435
#include "xls/ir/node_util.h"
3536
#include "xls/ir/op.h"
3637
#include "xls/ir/package.h"
@@ -79,63 +80,104 @@ void MarkReachedFunctions(FunctionBase* func,
7980
}
8081
}
8182
}
82-
} // namespace
8383

84-
// Starting from the return_value(s), DFS over all nodes. Unvisited
85-
// nodes, or parameters, are dead.
86-
absl::StatusOr<bool> DeadFunctionEliminationPass::RunInternal(
87-
Package* p, const OptimizationPassOptions& options,
88-
PassResults* results) const {
89-
std::optional<FunctionBase*> top = p->GetTop();
90-
if (!top.has_value()) {
91-
return false;
84+
// Data structure describing the liveness of global constructs in a package.
85+
struct FunctionBaseLiveness {
86+
// The live roots of the package. This does not include FunctionBases which
87+
// are live because they are called/instantiated from other FunctionBases.
88+
std::vector<FunctionBase*> live_roots;
89+
90+
// Set of the live global channels. Only set for old-style procs.
91+
absl::flat_hash_set<Channel*> live_global_channels;
92+
};
93+
94+
absl::StatusOr<FunctionBaseLiveness> LivenessFromTopProc(Proc* top) {
95+
if (top->is_new_style_proc()) {
96+
XLS_ASSIGN_OR_RETURN(Elaboration elab, Elaboration::Elaborate(top));
97+
return FunctionBaseLiveness{.live_roots = std::vector<FunctionBase*>(
98+
elab.procs().begin(), elab.procs().end()),
99+
.live_global_channels = {}};
92100
}
93101

94-
// Mapping from proc->channel, where channel is a representative value
95-
// for all the channel names in the UnionFind.
96-
absl::flat_hash_map<Proc*, std::string> representative_channels;
102+
Package* p = top->package();
103+
104+
// Mapping from proc to channel, where channel is a representative value for
105+
// all the channel names in the UnionFind. If the proc uses no channels then
106+
// the value will be nullopt.
107+
absl::flat_hash_map<Proc*, std::optional<std::string_view>>
108+
representative_channels;
97109
representative_channels.reserve(p->procs().size());
98110
// Channels in the same proc will be union'd.
99-
UnionFind<std::string> channel_union;
100-
for (std::unique_ptr<Proc>& proc : p->procs()) {
101-
std::optional<std::string> representative_proc_channel;
111+
UnionFind<std::string_view> channel_union;
112+
for (const std::unique_ptr<Proc>& proc : p->procs()) {
113+
std::optional<std::string_view> representative_proc_channel;
102114
for (Node* node : proc->nodes()) {
103115
if (IsChannelNode(node)) {
104-
std::string channel;
105-
if (node->Is<Send>()) {
106-
channel = node->As<Send>()->channel_name();
107-
} else if (node->Is<Receive>()) {
108-
channel = node->As<Receive>()->channel_name();
109-
} else {
110-
return absl::NotFoundError(absl::StrFormat(
111-
"No channel associated with node %s", node->GetName()));
112-
}
113-
channel_union.Insert(channel);
116+
XLS_ASSIGN_OR_RETURN(Channel * channel, GetChannelUsedByNode(node));
117+
channel_union.Insert(channel->name());
114118
if (representative_proc_channel.has_value()) {
115-
channel_union.Union(representative_proc_channel.value(), channel);
119+
channel_union.Union(representative_proc_channel.value(),
120+
channel->name());
116121
} else {
117-
representative_proc_channel = channel;
118-
representative_channels.insert({proc.get(), channel});
122+
representative_proc_channel = channel->name();
119123
}
120124
}
121125
}
126+
representative_channels[proc.get()] = representative_proc_channel;
122127
}
123128

124-
absl::flat_hash_set<FunctionBase*> reached;
125-
MarkReachedFunctions(top.value(), &reached);
126-
std::optional<std::string> top_proc_representative_channel;
127-
if ((*top)->IsProc()) {
128-
auto itr = representative_channels.find(top.value()->AsProcOrDie());
129-
if (itr != representative_channels.end()) {
130-
top_proc_representative_channel = channel_union.Find(itr->second);
131-
for (auto [proc, representative_channel] : representative_channels) {
132-
if (channel_union.Find(representative_channel) ==
133-
*top_proc_representative_channel) {
134-
MarkReachedFunctions(proc, &reached);
135-
}
129+
FunctionBaseLiveness liveness;
130+
131+
// Add procs to the live set if they are connnected to `top` via channels.
132+
for (const std::unique_ptr<Proc>& proc : p->procs()) {
133+
if (proc.get() == top) {
134+
liveness.live_roots.push_back(proc.get());
135+
continue;
136+
}
137+
if (representative_channels.at(top).has_value() &&
138+
representative_channels.at(proc.get()) &&
139+
channel_union.Find(representative_channels.at(top).value()) ==
140+
channel_union.Find(
141+
representative_channels.at(proc.get()).value())) {
142+
liveness.live_roots.push_back(proc.get());
143+
}
144+
}
145+
146+
// Add channels to the live set if they are connnected to `top`.
147+
if (representative_channels.at(top).has_value()) {
148+
for (Channel* channel : p->channels()) {
149+
if (channel_union.Find(channel->name()) ==
150+
channel_union.Find(representative_channels.at(top).value())) {
151+
liveness.live_global_channels.insert(channel);
136152
}
137153
}
138154
}
155+
return liveness;
156+
}
157+
158+
} // namespace
159+
160+
// Starting from the return_value(s), DFS over all nodes. Unvisited
161+
// nodes, or parameters, are dead.
162+
absl::StatusOr<bool> DeadFunctionEliminationPass::RunInternal(
163+
Package* p, const OptimizationPassOptions& options,
164+
PassResults* results) const {
165+
std::optional<FunctionBase*> top = p->GetTop();
166+
if (!top.has_value()) {
167+
return false;
168+
}
169+
170+
FunctionBaseLiveness liveness;
171+
if ((*top)->IsProc()) {
172+
XLS_ASSIGN_OR_RETURN(liveness, LivenessFromTopProc((*top)->AsProcOrDie()));
173+
} else {
174+
liveness.live_roots = {*top};
175+
}
176+
177+
absl::flat_hash_set<FunctionBase*> reached;
178+
for (FunctionBase* fb : liveness.live_roots) {
179+
MarkReachedFunctions(fb, &reached);
180+
}
139181

140182
// Accumulate a list of FunctionBases to unlink.
141183
bool changed = false;
@@ -147,19 +189,15 @@ absl::StatusOr<bool> DeadFunctionEliminationPass::RunInternal(
147189
}
148190
}
149191

150-
// Find any channels which are only used by now-removed procs.
151-
std::vector<std::string> channels_to_remove;
192+
// Remove dead channels.
193+
std::vector<Channel*> channels_to_remove;
152194
channels_to_remove.reserve(p->channels().size());
153195
for (Channel* channel : p->channels()) {
154-
if (!top_proc_representative_channel.has_value() ||
155-
channel_union.Find(std::string{channel->name()}) !=
156-
*top_proc_representative_channel) {
157-
channels_to_remove.push_back(std::string{channel->name()});
196+
if (!liveness.live_global_channels.contains(channel)) {
197+
channels_to_remove.push_back(channel);
158198
}
159199
}
160-
// Now remove any channels which are only used by now-removed procs.
161-
for (const std::string& channel_name : channels_to_remove) {
162-
XLS_ASSIGN_OR_RETURN(Channel * channel, p->GetChannel(channel_name));
200+
for (Channel* channel : channels_to_remove) {
163201
XLS_VLOG(2) << "Removing channel: " << channel->name();
164202
XLS_RETURN_IF_ERROR(p->RemoveChannel(channel));
165203
changed = true;

xls/passes/dfe_pass.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,8 @@
2323

2424
namespace xls {
2525

26-
// class DeadCodeEliminationPass iterates up from a functions result
27-
// nodes and marks all visited node. After that, all unvisited nodes
28-
// are considered dead.
26+
// This pass removes unreachable procs/blocks/functions from the package. The
27+
// pass requires `top` be set in order remove any constructs.
2928
class DeadFunctionEliminationPass : public OptimizationPass {
3029
public:
3130
explicit DeadFunctionEliminationPass()

xls/passes/dfe_pass_test.cc

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,22 @@
1616

1717
#include <memory>
1818
#include <string_view>
19+
#include <vector>
1920

2021
#include "gmock/gmock.h"
2122
#include "gtest/gtest.h"
2223
#include "absl/status/statusor.h"
2324
#include "xls/common/status/matchers.h"
25+
#include "xls/common/status/status_macros.h"
2426
#include "xls/ir/bits.h"
27+
#include "xls/ir/channel.h"
2528
#include "xls/ir/function.h"
2629
#include "xls/ir/function_base.h"
2730
#include "xls/ir/function_builder.h"
2831
#include "xls/ir/ir_matcher.h"
2932
#include "xls/ir/ir_test_base.h"
3033
#include "xls/ir/package.h"
34+
#include "xls/ir/value.h"
3135
#include "xls/passes/optimization_pass.h"
3236
#include "xls/passes/pass_base.h"
3337

@@ -55,6 +59,21 @@ class DeadFunctionEliminationPassTest : public IrTestBase {
5559
fb.Param("arg", p->GetBitsType(32));
5660
return fb.Build();
5761
}
62+
63+
absl::StatusOr<Proc*> CreateNewStyleAccumProc(std::string_view proc_name,
64+
Package* package) {
65+
TokenlessProcBuilder pb(NewStyleProc(), proc_name, "tkn", package);
66+
BValue accum = pb.StateElement("accum", Value(UBits(0, 32)));
67+
XLS_ASSIGN_OR_RETURN(ReceiveChannelReference * in_channel,
68+
pb.AddInputChannel("in_ch", package->GetBitsType(32)));
69+
BValue input = pb.Receive(in_channel);
70+
BValue next_accum = pb.Add(accum, input);
71+
XLS_ASSIGN_OR_RETURN(
72+
SendChannelReference * out_channel,
73+
pb.AddOutputChannel("out_ch", package->GetBitsType(32)));
74+
pb.Send(out_channel, next_accum);
75+
return pb.Build({next_accum});
76+
}
5877
};
5978

6079
TEST_F(DeadFunctionEliminationPassTest, NoDeadFunctions) {
@@ -430,5 +449,66 @@ proc test_proc3(tkn: token, state:(), init={()}) {
430449
EXPECT_THAT(p->channels(), IsEmpty());
431450
}
432451

452+
TEST_F(DeadFunctionEliminationPassTest, SingleNewStyleProc) {
453+
auto p = CreatePackage();
454+
XLS_ASSERT_OK(CreateNewStyleAccumProc("my_proc", p.get()).status());
455+
456+
EXPECT_THAT(p->GetFunctionBases(), UnorderedElementsAre(m::Proc("my_proc")));
457+
458+
EXPECT_THAT(Run(p.get()), IsOkAndHolds(false));
459+
460+
EXPECT_THAT(p->GetFunctionBases(), UnorderedElementsAre(m::Proc("my_proc")));
461+
}
462+
463+
TEST_F(DeadFunctionEliminationPassTest, MultipleNewStyleProcs) {
464+
auto p = CreatePackage();
465+
XLS_ASSERT_OK(CreateNewStyleAccumProc("my_proc1", p.get()).status());
466+
XLS_ASSERT_OK(CreateNewStyleAccumProc("my_proc2", p.get()).status());
467+
XLS_ASSERT_OK(CreateNewStyleAccumProc("my_proc3", p.get()).status());
468+
XLS_ASSERT_OK(p->SetTopByName("my_proc2"));
469+
470+
EXPECT_THAT(p->GetFunctionBases(),
471+
UnorderedElementsAre(m::Proc("my_proc1"), m::Proc("my_proc2"),
472+
m::Proc("my_proc3")));
473+
474+
EXPECT_THAT(Run(p.get()), IsOkAndHolds(true));
475+
476+
EXPECT_THAT(p->GetFunctionBases(), UnorderedElementsAre(m::Proc("my_proc2")));
477+
}
478+
479+
TEST_F(DeadFunctionEliminationPassTest, NewStyleProcWithInstantiations) {
480+
auto p = CreatePackage();
481+
XLS_ASSERT_OK_AND_ASSIGN(Proc * my_proc1,
482+
CreateNewStyleAccumProc("my_proc1", p.get()));
483+
XLS_ASSERT_OK_AND_ASSIGN(Proc * my_proc2,
484+
CreateNewStyleAccumProc("my_proc2", p.get()));
485+
XLS_ASSERT_OK(CreateNewStyleAccumProc("my_proc3", p.get()).status());
486+
487+
TokenlessProcBuilder pb(NewStyleProc(), "top_proc", "tkn", p.get());
488+
XLS_ASSERT_OK_AND_ASSIGN(ChannelReferences the_channel,
489+
pb.AddChannel("the_channel", p->GetBitsType(32)));
490+
XLS_ASSERT_OK(
491+
pb.InstantiateProc("inst0", my_proc1,
492+
std::vector<ChannelReference*>{the_channel.receive_ref,
493+
the_channel.send_ref}));
494+
XLS_ASSERT_OK(
495+
pb.InstantiateProc("inst1", my_proc2,
496+
std::vector<ChannelReference*>{the_channel.receive_ref,
497+
the_channel.send_ref}));
498+
XLS_ASSERT_OK(pb.Build({}).status());
499+
500+
XLS_ASSERT_OK(p->SetTopByName("top_proc"));
501+
502+
EXPECT_THAT(p->GetFunctionBases(),
503+
UnorderedElementsAre(m::Proc("top_proc"), m::Proc("my_proc1"),
504+
m::Proc("my_proc2"), m::Proc("my_proc3")));
505+
506+
EXPECT_THAT(Run(p.get()), IsOkAndHolds(true));
507+
508+
EXPECT_THAT(p->GetFunctionBases(),
509+
UnorderedElementsAre(m::Proc("top_proc"), m::Proc("my_proc1"),
510+
m::Proc("my_proc2")));
511+
}
512+
433513
} // namespace
434514
} // namespace xls

0 commit comments

Comments
 (0)