Skip to content

Commit 749c04e

Browse files
allightcopybara-github
authored andcommitted
Register sharing pass
Add a pass to share registers between stages identified by register-chaining analysis. PiperOrigin-RevId: 612635721
1 parent 106116a commit 749c04e

File tree

4 files changed

+613
-0
lines changed

4 files changed

+613
-0
lines changed

xls/codegen/BUILD

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,28 @@ cc_library(
628628
],
629629
)
630630

631+
cc_library(
632+
name = "register_combining_pass",
633+
srcs = ["register_combining_pass.cc"],
634+
hdrs = ["register_combining_pass.h"],
635+
deps = [
636+
":codegen_pass",
637+
":register_chaining_analysis",
638+
"//xls/common/logging",
639+
"//xls/common/status:ret_check",
640+
"//xls/common/status:status_macros",
641+
"//xls/ir",
642+
"//xls/ir:register",
643+
"//xls/passes:pass_base",
644+
"@com_google_absl//absl/algorithm:container",
645+
"@com_google_absl//absl/container:flat_hash_set",
646+
"@com_google_absl//absl/log:check",
647+
"@com_google_absl//absl/status",
648+
"@com_google_absl//absl/status:statusor",
649+
"@com_google_absl//absl/types:span",
650+
],
651+
)
652+
631653
cc_library(
632654
name = "register_legalization_pass",
633655
srcs = ["register_legalization_pass.cc"],
@@ -947,6 +969,35 @@ cc_test(
947969
],
948970
)
949971

972+
cc_test(
973+
name = "register_combining_pass_test",
974+
srcs = ["register_combining_pass_test.cc"],
975+
deps = [
976+
":block_conversion",
977+
":codegen_options",
978+
":codegen_pass",
979+
":module_signature_cc_proto",
980+
":register_combining_pass",
981+
"//xls/common:xls_gunit",
982+
"//xls/common:xls_gunit_main",
983+
"//xls/common/status:matchers",
984+
"//xls/ir",
985+
"//xls/ir:bits",
986+
"//xls/ir:function_builder",
987+
"//xls/ir:ir_matcher",
988+
"//xls/ir:ir_test_base",
989+
"//xls/ir:op",
990+
"//xls/ir:register",
991+
"//xls/ir:source_location",
992+
"//xls/passes:pass_base",
993+
"//xls/scheduling:pipeline_schedule",
994+
"//xls/tools:codegen_flags_cc_proto",
995+
"@com_google_absl//absl/status:statusor",
996+
"@com_google_absl//absl/strings",
997+
"@com_google_absl//absl/strings:str_format",
998+
],
999+
)
1000+
9501001
cc_test(
9511002
name = "register_legalization_pass_test",
9521003
srcs = ["register_legalization_pass_test.cc"],
Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
// Copyright 2024 The XLS Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "xls/codegen/register_combining_pass.h"
16+
17+
#include <vector>
18+
19+
#include "absl/algorithm/container.h"
20+
#include "absl/container/flat_hash_set.h"
21+
#include "absl/log/check.h"
22+
#include "absl/status/status.h"
23+
#include "absl/status/statusor.h"
24+
#include "absl/types/span.h"
25+
#include "xls/codegen/codegen_pass.h"
26+
#include "xls/codegen/register_chaining_analysis.h"
27+
#include "xls/common/logging/logging.h"
28+
#include "xls/common/status/ret_check.h"
29+
#include "xls/common/status/status_macros.h"
30+
#include "xls/ir/node.h"
31+
#include "xls/ir/nodes.h"
32+
#include "xls/ir/register.h"
33+
#include "xls/passes/pass_base.h"
34+
35+
namespace xls::verilog {
36+
37+
namespace {
38+
absl::Status CombineRegisters(absl::Span<const RegisterData> mutex_group,
39+
CodegenPassUnit* unit) {
40+
XLS_RET_CHECK_GE(mutex_group.size(), 2)
41+
<< "Attempting to combine a single register is not meaningful. Single "
42+
"element mutex groups should have been filtered out.";
43+
// Registers are listed so that 'last' one is at the end.
44+
// The register with a loop-back write (write from a later stage) is always
45+
// at the front, if one exists.
46+
// Merge from the front back.
47+
const RegisterData& first = mutex_group.front();
48+
std::vector<Node*> cleanup_nodes;
49+
absl::flat_hash_set<Register*> cleanup_regs;
50+
51+
// No need to change load-enable bits, we're merging into the top which has
52+
// the right bits already.
53+
XLS_VLOG(2) << "Collapsing " << mutex_group.size() << " registers into "
54+
<< mutex_group.front().reg->ToString();
55+
for (const RegisterData& merge : mutex_group.subspan(1)) {
56+
XLS_RETURN_IF_ERROR(merge.read->ReplaceUsesWith(first.read));
57+
cleanup_regs.insert(merge.reg);
58+
cleanup_nodes.push_back(merge.read);
59+
cleanup_nodes.push_back(merge.write);
60+
}
61+
62+
// Do cleanup.
63+
for (auto& stage : unit->streaming_io_and_pipeline.pipeline_registers) {
64+
std::erase_if(stage, [&](const PipelineRegister& pr) {
65+
return cleanup_regs.contains(pr.reg);
66+
});
67+
}
68+
for (auto& state_reg : unit->streaming_io_and_pipeline.state_registers) {
69+
CHECK(!state_reg || !cleanup_regs.contains(state_reg->reg))
70+
<< "Removed a state register: " << state_reg->reg->ToString();
71+
}
72+
for (Node* n : cleanup_nodes) {
73+
XLS_RETURN_IF_ERROR(unit->block->RemoveNode(n)) << "can't remove " << n;
74+
}
75+
for (Register* r : cleanup_regs) {
76+
XLS_RETURN_IF_ERROR(unit->block->RemoveRegister(r));
77+
}
78+
return absl::OkStatus();
79+
}
80+
} // namespace
81+
82+
absl::StatusOr<bool> RegisterCombiningPass::RunInternal(
83+
CodegenPassUnit* unit, const CodegenPassOptions& options,
84+
PassResults* results) const {
85+
if (!unit->concurrent_stages) {
86+
return false;
87+
}
88+
std::vector<RegisterData> candidate_registers;
89+
candidate_registers.reserve(unit->block->GetRegisters().size());
90+
// State registers (but not their valid/reset regs) are candidates for
91+
// merging.
92+
XLS_VLOG(2) << unit->block->DumpIr();
93+
for (const auto& maybe_reg :
94+
unit->streaming_io_and_pipeline.state_registers) {
95+
if (maybe_reg) {
96+
CHECK(!maybe_reg->next_values.empty());
97+
auto write_stage =
98+
absl::c_min_element(maybe_reg->next_values, [](const auto& l,
99+
const auto& r) {
100+
return l.stage < r.stage;
101+
})->stage;
102+
if (maybe_reg->read_stage == write_stage) {
103+
// Immediate back edge.
104+
continue;
105+
}
106+
candidate_registers.push_back({.reg = maybe_reg->reg,
107+
.read = maybe_reg->reg_read,
108+
.read_stage = maybe_reg->read_stage,
109+
.write = maybe_reg->reg_write,
110+
.write_stage = write_stage});
111+
}
112+
}
113+
// pipeline registers (but not their valid/reset regs) are candidates for
114+
// merging.
115+
for (const auto& stg_regs :
116+
unit->streaming_io_and_pipeline.pipeline_registers) {
117+
for (const auto& reg : stg_regs) {
118+
CHECK(unit->streaming_io_and_pipeline.node_to_stage_map.contains(
119+
reg.reg_read))
120+
<< reg.reg_read;
121+
CHECK(unit->streaming_io_and_pipeline.node_to_stage_map.contains(
122+
reg.reg_write))
123+
<< reg.reg_write;
124+
Stage read_stage =
125+
unit->streaming_io_and_pipeline.node_to_stage_map.at(reg.reg_read);
126+
Stage write_stage =
127+
unit->streaming_io_and_pipeline.node_to_stage_map.at(reg.reg_write);
128+
CHECK_EQ(write_stage + 1, read_stage)
129+
<< "pipeline register skipping stage? " << reg.reg->ToString()
130+
<< "\nread: " << reg.reg_read << "\nwrite: " << reg.reg_write;
131+
candidate_registers.push_back({
132+
.reg = reg.reg,
133+
.read = reg.reg_read,
134+
.read_stage = read_stage,
135+
.write = reg.reg_write,
136+
.write_stage = write_stage,
137+
});
138+
}
139+
}
140+
// chains of registers which are possibly combinable.
141+
RegisterChains reg_groups;
142+
143+
for (const RegisterData& rd : candidate_registers) {
144+
reg_groups.InsertAndReduce(rd);
145+
}
146+
XLS_ASSIGN_OR_RETURN(
147+
std::vector<std::vector<RegisterData>> mutex_chains,
148+
reg_groups.SplitBetweenMutexRegions(*unit->concurrent_stages, options));
149+
bool changed = !mutex_chains.empty();
150+
151+
for (const std::vector<RegisterData>& group : mutex_chains) {
152+
XLS_RETURN_IF_ERROR(CombineRegisters(group, unit));
153+
}
154+
155+
if (changed) {
156+
unit->GcMetadata();
157+
}
158+
159+
return changed;
160+
}
161+
162+
} // namespace xls::verilog

xls/codegen/register_combining_pass.h

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright 2024 The XLS Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef XLS_CODEGEN_REGISTER_COMBINING_PASS_H_
16+
#define XLS_CODEGEN_REGISTER_COMBINING_PASS_H_
17+
18+
#include "absl/status/statusor.h"
19+
#include "xls/codegen/codegen_pass.h"
20+
#include "xls/passes/pass_base.h"
21+
22+
namespace xls::verilog {
23+
// Eliminates (and removes) redundant registers by allowing registers to be
24+
// shared across many stages.
25+
class RegisterCombiningPass : public CodegenPass {
26+
public:
27+
RegisterCombiningPass()
28+
: CodegenPass("register_combining",
29+
"Combine mutually exclusive registers") {}
30+
~RegisterCombiningPass() override = default;
31+
32+
absl::StatusOr<bool> RunInternal(CodegenPassUnit* unit,
33+
const CodegenPassOptions& options,
34+
PassResults* results) const override;
35+
};
36+
37+
} // namespace xls::verilog
38+
39+
#endif // XLS_CODEGEN_REGISTER_COMBINING_PASS_H_

0 commit comments

Comments
 (0)