Skip to content

Commit

Permalink
Add support for new-style procs to dead function elimination pass.
Browse files Browse the repository at this point in the history
For new-style procs, elaboration trivially determines proc liveness.

PiperOrigin-RevId: 597682240
  • Loading branch information
meheffernan authored and copybara-github committed Jan 12, 2024
1 parent 27c3660 commit 2766e34
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 54 deletions.
4 changes: 4 additions & 0 deletions xls/passes/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,7 @@ cc_library(
"//xls/data_structures:union_find",
"//xls/ir",
"//xls/ir:channel",
"//xls/ir:elaboration",
"//xls/ir:node_util",
"//xls/ir:op",
],
Expand Down Expand Up @@ -2081,11 +2082,14 @@ cc_test(
"//xls/common:xls_gunit",
"//xls/common:xls_gunit_main",
"//xls/common/status:matchers",
"//xls/common/status:status_macros",
"//xls/ir",
"//xls/ir:bits",
"//xls/ir:channel",
"//xls/ir:function_builder",
"//xls/ir:ir_matcher",
"//xls/ir:ir_test_base",
"//xls/ir:value",
],
)

Expand Down
140 changes: 89 additions & 51 deletions xls/passes/dfe_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,20 @@
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "xls/common/logging/logging.h"
#include "xls/common/status/status_macros.h"
#include "xls/data_structures/union_find.h"
#include "xls/ir/block.h"
#include "xls/ir/channel.h"
#include "xls/ir/elaboration.h"
#include "xls/ir/function_base.h"
#include "xls/ir/instantiation.h"
#include "xls/ir/node_util.h"
#include "xls/ir/op.h"
#include "xls/ir/package.h"
Expand Down Expand Up @@ -79,63 +80,104 @@ void MarkReachedFunctions(FunctionBase* func,
}
}
}
} // namespace

// Starting from the return_value(s), DFS over all nodes. Unvisited
// nodes, or parameters, are dead.
absl::StatusOr<bool> DeadFunctionEliminationPass::RunInternal(
Package* p, const OptimizationPassOptions& options,
PassResults* results) const {
std::optional<FunctionBase*> top = p->GetTop();
if (!top.has_value()) {
return false;
// Data structure describing the liveness of global constructs in a package.
struct FunctionBaseLiveness {
// The live roots of the package. This does not include FunctionBases which
// are live because they are called/instantiated from other FunctionBases.
std::vector<FunctionBase*> live_roots;

// Set of the live global channels. Only set for old-style procs.
absl::flat_hash_set<Channel*> live_global_channels;
};

absl::StatusOr<FunctionBaseLiveness> LivenessFromTopProc(Proc* top) {
if (top->is_new_style_proc()) {
XLS_ASSIGN_OR_RETURN(Elaboration elab, Elaboration::Elaborate(top));
return FunctionBaseLiveness{.live_roots = std::vector<FunctionBase*>(
elab.procs().begin(), elab.procs().end()),
.live_global_channels = {}};
}

// Mapping from proc->channel, where channel is a representative value
// for all the channel names in the UnionFind.
absl::flat_hash_map<Proc*, std::string> representative_channels;
Package* p = top->package();

// Mapping from proc to channel, where channel is a representative value for
// all the channel names in the UnionFind. If the proc uses no channels then
// the value will be nullopt.
absl::flat_hash_map<Proc*, std::optional<std::string_view>>
representative_channels;
representative_channels.reserve(p->procs().size());
// Channels in the same proc will be union'd.
UnionFind<std::string> channel_union;
for (std::unique_ptr<Proc>& proc : p->procs()) {
std::optional<std::string> representative_proc_channel;
UnionFind<std::string_view> channel_union;
for (const std::unique_ptr<Proc>& proc : p->procs()) {
std::optional<std::string_view> representative_proc_channel;
for (Node* node : proc->nodes()) {
if (IsChannelNode(node)) {
std::string channel;
if (node->Is<Send>()) {
channel = node->As<Send>()->channel_name();
} else if (node->Is<Receive>()) {
channel = node->As<Receive>()->channel_name();
} else {
return absl::NotFoundError(absl::StrFormat(
"No channel associated with node %s", node->GetName()));
}
channel_union.Insert(channel);
XLS_ASSIGN_OR_RETURN(Channel * channel, GetChannelUsedByNode(node));
channel_union.Insert(channel->name());
if (representative_proc_channel.has_value()) {
channel_union.Union(representative_proc_channel.value(), channel);
channel_union.Union(representative_proc_channel.value(),
channel->name());
} else {
representative_proc_channel = channel;
representative_channels.insert({proc.get(), channel});
representative_proc_channel = channel->name();
}
}
}
representative_channels[proc.get()] = representative_proc_channel;
}

absl::flat_hash_set<FunctionBase*> reached;
MarkReachedFunctions(top.value(), &reached);
std::optional<std::string> top_proc_representative_channel;
if ((*top)->IsProc()) {
auto itr = representative_channels.find(top.value()->AsProcOrDie());
if (itr != representative_channels.end()) {
top_proc_representative_channel = channel_union.Find(itr->second);
for (auto [proc, representative_channel] : representative_channels) {
if (channel_union.Find(representative_channel) ==
*top_proc_representative_channel) {
MarkReachedFunctions(proc, &reached);
}
FunctionBaseLiveness liveness;

// Add procs to the live set if they are connnected to `top` via channels.
for (const std::unique_ptr<Proc>& proc : p->procs()) {
if (proc.get() == top) {
liveness.live_roots.push_back(proc.get());
continue;
}
if (representative_channels.at(top).has_value() &&
representative_channels.at(proc.get()) &&
channel_union.Find(representative_channels.at(top).value()) ==
channel_union.Find(
representative_channels.at(proc.get()).value())) {
liveness.live_roots.push_back(proc.get());
}
}

// Add channels to the live set if they are connnected to `top`.
if (representative_channels.at(top).has_value()) {
for (Channel* channel : p->channels()) {
if (channel_union.Find(channel->name()) ==
channel_union.Find(representative_channels.at(top).value())) {
liveness.live_global_channels.insert(channel);
}
}
}
return liveness;
}

} // namespace

// Starting from the return_value(s), DFS over all nodes. Unvisited
// nodes, or parameters, are dead.
absl::StatusOr<bool> DeadFunctionEliminationPass::RunInternal(
Package* p, const OptimizationPassOptions& options,
PassResults* results) const {
std::optional<FunctionBase*> top = p->GetTop();
if (!top.has_value()) {
return false;
}

FunctionBaseLiveness liveness;
if ((*top)->IsProc()) {
XLS_ASSIGN_OR_RETURN(liveness, LivenessFromTopProc((*top)->AsProcOrDie()));
} else {
liveness.live_roots = {*top};
}

absl::flat_hash_set<FunctionBase*> reached;
for (FunctionBase* fb : liveness.live_roots) {
MarkReachedFunctions(fb, &reached);
}

// Accumulate a list of FunctionBases to unlink.
bool changed = false;
Expand All @@ -147,19 +189,15 @@ absl::StatusOr<bool> DeadFunctionEliminationPass::RunInternal(
}
}

// Find any channels which are only used by now-removed procs.
std::vector<std::string> channels_to_remove;
// Remove dead channels.
std::vector<Channel*> channels_to_remove;
channels_to_remove.reserve(p->channels().size());
for (Channel* channel : p->channels()) {
if (!top_proc_representative_channel.has_value() ||
channel_union.Find(std::string{channel->name()}) !=
*top_proc_representative_channel) {
channels_to_remove.push_back(std::string{channel->name()});
if (!liveness.live_global_channels.contains(channel)) {
channels_to_remove.push_back(channel);
}
}
// Now remove any channels which are only used by now-removed procs.
for (const std::string& channel_name : channels_to_remove) {
XLS_ASSIGN_OR_RETURN(Channel * channel, p->GetChannel(channel_name));
for (Channel* channel : channels_to_remove) {
XLS_VLOG(2) << "Removing channel: " << channel->name();
XLS_RETURN_IF_ERROR(p->RemoveChannel(channel));
changed = true;
Expand Down
5 changes: 2 additions & 3 deletions xls/passes/dfe_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,8 @@

namespace xls {

// class DeadCodeEliminationPass iterates up from a functions result
// nodes and marks all visited node. After that, all unvisited nodes
// are considered dead.
// This pass removes unreachable procs/blocks/functions from the package. The
// pass requires `top` be set in order remove any constructs.
class DeadFunctionEliminationPass : public OptimizationPass {
public:
explicit DeadFunctionEliminationPass()
Expand Down
80 changes: 80 additions & 0 deletions xls/passes/dfe_pass_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,22 @@

#include <memory>
#include <string_view>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "absl/status/statusor.h"
#include "xls/common/status/matchers.h"
#include "xls/common/status/status_macros.h"
#include "xls/ir/bits.h"
#include "xls/ir/channel.h"
#include "xls/ir/function.h"
#include "xls/ir/function_base.h"
#include "xls/ir/function_builder.h"
#include "xls/ir/ir_matcher.h"
#include "xls/ir/ir_test_base.h"
#include "xls/ir/package.h"
#include "xls/ir/value.h"
#include "xls/passes/optimization_pass.h"
#include "xls/passes/pass_base.h"

Expand Down Expand Up @@ -55,6 +59,21 @@ class DeadFunctionEliminationPassTest : public IrTestBase {
fb.Param("arg", p->GetBitsType(32));
return fb.Build();
}

absl::StatusOr<Proc*> CreateNewStyleAccumProc(std::string_view proc_name,
Package* package) {
TokenlessProcBuilder pb(NewStyleProc(), proc_name, "tkn", package);
BValue accum = pb.StateElement("accum", Value(UBits(0, 32)));
XLS_ASSIGN_OR_RETURN(ReceiveChannelReference * in_channel,
pb.AddInputChannel("in_ch", package->GetBitsType(32)));
BValue input = pb.Receive(in_channel);
BValue next_accum = pb.Add(accum, input);
XLS_ASSIGN_OR_RETURN(
SendChannelReference * out_channel,
pb.AddOutputChannel("out_ch", package->GetBitsType(32)));
pb.Send(out_channel, next_accum);
return pb.Build({next_accum});
}
};

TEST_F(DeadFunctionEliminationPassTest, NoDeadFunctions) {
Expand Down Expand Up @@ -430,5 +449,66 @@ proc test_proc3(tkn: token, state:(), init={()}) {
EXPECT_THAT(p->channels(), IsEmpty());
}

TEST_F(DeadFunctionEliminationPassTest, SingleNewStyleProc) {
auto p = CreatePackage();
XLS_ASSERT_OK(CreateNewStyleAccumProc("my_proc", p.get()).status());

EXPECT_THAT(p->GetFunctionBases(), UnorderedElementsAre(m::Proc("my_proc")));

EXPECT_THAT(Run(p.get()), IsOkAndHolds(false));

EXPECT_THAT(p->GetFunctionBases(), UnorderedElementsAre(m::Proc("my_proc")));
}

TEST_F(DeadFunctionEliminationPassTest, MultipleNewStyleProcs) {
auto p = CreatePackage();
XLS_ASSERT_OK(CreateNewStyleAccumProc("my_proc1", p.get()).status());
XLS_ASSERT_OK(CreateNewStyleAccumProc("my_proc2", p.get()).status());
XLS_ASSERT_OK(CreateNewStyleAccumProc("my_proc3", p.get()).status());
XLS_ASSERT_OK(p->SetTopByName("my_proc2"));

EXPECT_THAT(p->GetFunctionBases(),
UnorderedElementsAre(m::Proc("my_proc1"), m::Proc("my_proc2"),
m::Proc("my_proc3")));

EXPECT_THAT(Run(p.get()), IsOkAndHolds(true));

EXPECT_THAT(p->GetFunctionBases(), UnorderedElementsAre(m::Proc("my_proc2")));
}

TEST_F(DeadFunctionEliminationPassTest, NewStyleProcWithInstantiations) {
auto p = CreatePackage();
XLS_ASSERT_OK_AND_ASSIGN(Proc * my_proc1,
CreateNewStyleAccumProc("my_proc1", p.get()));
XLS_ASSERT_OK_AND_ASSIGN(Proc * my_proc2,
CreateNewStyleAccumProc("my_proc2", p.get()));
XLS_ASSERT_OK(CreateNewStyleAccumProc("my_proc3", p.get()).status());

TokenlessProcBuilder pb(NewStyleProc(), "top_proc", "tkn", p.get());
XLS_ASSERT_OK_AND_ASSIGN(ChannelReferences the_channel,
pb.AddChannel("the_channel", p->GetBitsType(32)));
XLS_ASSERT_OK(
pb.InstantiateProc("inst0", my_proc1,
std::vector<ChannelReference*>{the_channel.receive_ref,
the_channel.send_ref}));
XLS_ASSERT_OK(
pb.InstantiateProc("inst1", my_proc2,
std::vector<ChannelReference*>{the_channel.receive_ref,
the_channel.send_ref}));
XLS_ASSERT_OK(pb.Build({}).status());

XLS_ASSERT_OK(p->SetTopByName("top_proc"));

EXPECT_THAT(p->GetFunctionBases(),
UnorderedElementsAre(m::Proc("top_proc"), m::Proc("my_proc1"),
m::Proc("my_proc2"), m::Proc("my_proc3")));

EXPECT_THAT(Run(p.get()), IsOkAndHolds(true));

EXPECT_THAT(p->GetFunctionBases(),
UnorderedElementsAre(m::Proc("top_proc"), m::Proc("my_proc1"),
m::Proc("my_proc2")));
}

} // namespace
} // namespace xls

0 comments on commit 2766e34

Please sign in to comment.