Skip to content

Commit 07bd3a3

Browse files
richmckeevercopybara-github
authored andcommitted
Tweaks to VAST type inference logic.
- Constant-fold the expressions within type definitions, for more accurate inference of affected types. - Infer the types of untyped parameters. - Bug fixes such as some previously missing operand recursion. PiperOrigin-RevId: 639118619
1 parent d2da104 commit 07bd3a3

11 files changed

+1142
-30
lines changed

xls/codegen/BUILD

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,12 @@ cc_library(
150150
srcs = ["infer_vast_types.cc"],
151151
hdrs = ["infer_vast_types.h"],
152152
deps = [
153+
":fold_vast_constants",
153154
":vast",
154155
"//xls/common/status:status_macros",
156+
"//xls/ir:source_location",
155157
"@com_google_absl//absl/container:flat_hash_map",
158+
"@com_google_absl//absl/log",
156159
"@com_google_absl//absl/status",
157160
"@com_google_absl//absl/status:statusor",
158161
"@com_google_absl//absl/strings",
@@ -183,6 +186,45 @@ cc_test(
183186
],
184187
)
185188

189+
cc_library(
190+
name = "fold_vast_constants",
191+
srcs = ["fold_vast_constants.cc"],
192+
hdrs = ["fold_vast_constants.h"],
193+
deps = [
194+
":vast",
195+
"//xls/common/status:status_macros",
196+
"//xls/ir:bits",
197+
"//xls/ir:format_preference",
198+
"@com_google_absl//absl/container:flat_hash_map",
199+
"@com_google_absl//absl/functional:any_invocable",
200+
"@com_google_absl//absl/numeric:bits",
201+
"@com_google_absl//absl/status",
202+
"@com_google_absl//absl/status:statusor",
203+
"@com_google_absl//absl/strings",
204+
"@com_google_absl//absl/types:span",
205+
],
206+
)
207+
208+
cc_test(
209+
name = "fold_vast_constants_test",
210+
srcs = ["fold_vast_constants_test.cc"],
211+
deps = [
212+
":fold_vast_constants",
213+
":vast",
214+
"//xls/common:xls_gunit_main",
215+
"//xls/common/status:matchers",
216+
"//xls/common/status:status_macros",
217+
"//xls/ir:bits",
218+
"//xls/ir:bits_ops",
219+
"//xls/ir:format_preference",
220+
"//xls/ir:number_parser",
221+
"//xls/ir:source_location",
222+
"@com_google_absl//absl/log:check",
223+
"@com_google_absl//absl/status:statusor",
224+
"@com_google_googletest//:gtest",
225+
],
226+
)
227+
186228
cc_library(
187229
name = "finite_state_machine",
188230
srcs = ["finite_state_machine.cc"],

xls/codegen/fold_vast_constants.cc

Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
// Copyright 2024 The XLS Authors
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "xls/codegen/fold_vast_constants.h"
16+
17+
#include <cstdint>
18+
#include <memory>
19+
#include <optional>
20+
#include <utility>
21+
#include <vector>
22+
23+
#include "absl/container/flat_hash_map.h"
24+
#include "absl/numeric/bits.h"
25+
#include "absl/status/status.h"
26+
#include "absl/status/statusor.h"
27+
#include "absl/strings/str_cat.h"
28+
#include "absl/types/span.h"
29+
#include "xls/codegen/vast.h"
30+
#include "xls/common/status/status_macros.h"
31+
#include "xls/ir/bits.h"
32+
#include "xls/ir/format_preference.h"
33+
34+
namespace xls {
35+
namespace verilog {
36+
namespace {
37+
38+
// Based upon the util version.
39+
uint64_t Log2Ceiling(uint64_t n) {
40+
int floor = (absl::bit_width(n) - 1);
41+
return floor + (((n & (n - 1)) == 0) ? 0 : 1);
42+
}
43+
44+
// Helper class used internally by the exposed API to bind a type map and avoid
45+
// passing it around.
46+
class ConstantFoldingContext {
47+
public:
48+
explicit ConstantFoldingContext(
49+
const absl::flat_hash_map<Expression*, DataType*>& type_map)
50+
: type_map_(type_map) {}
51+
52+
absl::StatusOr<Expression*> FoldConstants(Expression* expr) {
53+
if (expr->IsLiteral()) {
54+
return expr;
55+
}
56+
if (auto* ref = dynamic_cast<ParameterRef*>(expr);
57+
ref && ref->parameter()->rhs()) {
58+
return FoldConstants(ref->parameter()->rhs());
59+
}
60+
if (auto* ref = dynamic_cast<EnumMemberRef*>(expr);
61+
ref && ref->member()->rhs()) {
62+
return FoldConstants(ref->member()->rhs());
63+
}
64+
if (auto* op = dynamic_cast<BinaryInfix*>(expr); op) {
65+
return FoldBinaryOp(op);
66+
}
67+
if (auto* ternary = dynamic_cast<Ternary*>(expr); ternary) {
68+
absl::StatusOr<int64_t> test_value = FoldEntireExpr(ternary->test());
69+
absl::StatusOr<int64_t> consequent_value =
70+
FoldEntireExpr(ternary->consequent());
71+
absl::StatusOr<int64_t> alternate_value =
72+
FoldEntireExpr(ternary->alternate());
73+
if (test_value.ok() && consequent_value.ok() && alternate_value.ok()) {
74+
return *test_value ? MakeFoldedConstant(ternary, *consequent_value)
75+
: MakeFoldedConstant(ternary, *alternate_value);
76+
}
77+
}
78+
if (auto* call = dynamic_cast<SystemFunctionCall*>(expr); call) {
79+
if (call->name() == "clog2" && call->args().has_value() &&
80+
(*call->args()).size() == 1) {
81+
absl::StatusOr<int64_t> arg_value = FoldEntireExpr((*call->args())[0]);
82+
if (arg_value.ok()) {
83+
return MakeFoldedConstant(call, Log2Ceiling(*arg_value));
84+
}
85+
}
86+
}
87+
return expr;
88+
}
89+
90+
absl::StatusOr<int64_t> FoldEntireExpr(Expression* expr) {
91+
XLS_ASSIGN_OR_RETURN(Expression * folded, FoldConstants(expr));
92+
if (folded->IsLiteral()) {
93+
return folded->AsLiteralOrDie()->ToInt64();
94+
}
95+
return absl::InvalidArgumentError(
96+
absl::StrCat("Expression does not entirely fold to a constant: ",
97+
expr->Emit(nullptr)));
98+
}
99+
100+
absl::StatusOr<Literal*> MakeFoldedConstant(Expression* original,
101+
int64_t value) {
102+
int64_t effective_size = 32;
103+
bool is_signed = true;
104+
const auto it = type_map_.find(original);
105+
if (it != type_map_.end()) {
106+
XLS_ASSIGN_OR_RETURN(effective_size, it->second->FlatBitCountAsInt64());
107+
is_signed = it->second->is_signed();
108+
}
109+
return original->file()->Make<Literal>(
110+
original->loc(),
111+
is_signed ? SBits(value, effective_size)
112+
: UBits(static_cast<uint64_t>(value), effective_size),
113+
value < 0 ? FormatPreference::kHex : FormatPreference::kUnsignedDecimal,
114+
/*effective_bit_count=*/effective_size,
115+
/*emit_bit_count=*/effective_size != 32,
116+
/*declared_as_signed=*/is_signed);
117+
}
118+
119+
absl::StatusOr<DataType*> FoldConstants(DataType* data_type) {
120+
if (data_type->FlatBitCountAsInt64().ok()) {
121+
return data_type;
122+
}
123+
if (auto* bit_vector_type = dynamic_cast<BitVectorType*>(data_type);
124+
bit_vector_type && !bit_vector_type->size_expr()->IsLiteral()) {
125+
XLS_ASSIGN_OR_RETURN(int64_t folded_size,
126+
FoldEntireExpr(bit_vector_type->size_expr()));
127+
return data_type->file()->Make<BitVectorType>(
128+
data_type->loc(),
129+
data_type->file()->PlainLiteral(static_cast<int32_t>(folded_size),
130+
data_type->loc()),
131+
data_type->is_signed(),
132+
/*size_expr_is_max=*/bit_vector_type->max().has_value());
133+
}
134+
if (auto* array_type = dynamic_cast<UnpackedArrayType*>(data_type);
135+
array_type) {
136+
XLS_ASSIGN_OR_RETURN(std::vector<int64_t> folded_dims,
137+
FoldDims(array_type->dims()));
138+
XLS_ASSIGN_OR_RETURN(DataType * folded_element_type,
139+
FoldConstants(array_type->element_type()));
140+
return array_type->file()->template Make<UnpackedArrayType>(
141+
data_type->loc(), folded_element_type, folded_dims);
142+
}
143+
if (auto* array_type = dynamic_cast<PackedArrayType*>(data_type);
144+
array_type) {
145+
XLS_ASSIGN_OR_RETURN(std::vector<int64_t> folded_dims,
146+
FoldDims(array_type->dims()));
147+
XLS_ASSIGN_OR_RETURN(DataType * folded_element_type,
148+
FoldConstants(array_type->element_type()));
149+
return array_type->file()->template Make<PackedArrayType>(
150+
data_type->loc(), folded_element_type, folded_dims,
151+
/*dims_are_max=*/array_type->dims_are_max());
152+
}
153+
if (auto* enum_def = dynamic_cast<Enum*>(data_type); enum_def) {
154+
XLS_ASSIGN_OR_RETURN(DataType * folded_base_type,
155+
FoldConstants(enum_def->BaseType()));
156+
return data_type->file()->Make<Enum>(enum_def->loc(), enum_def->kind(),
157+
folded_base_type,
158+
enum_def->members());
159+
}
160+
if (auto* struct_def = dynamic_cast<Struct*>(data_type); struct_def) {
161+
XLS_ASSIGN_OR_RETURN(std::vector<Def*> folded_member_defs,
162+
FoldTypesOfDefs(struct_def->members()));
163+
return struct_def->file()->Make<Struct>(struct_def->loc(),
164+
folded_member_defs);
165+
}
166+
if (auto* type_def = dynamic_cast<TypedefType*>(data_type); type_def) {
167+
return FoldConstants(type_def->BaseType());
168+
}
169+
return absl::InternalError(absl::StrCat("Could not constant-fold type: ",
170+
data_type->Emit(nullptr)));
171+
}
172+
173+
private:
174+
absl::StatusOr<std::vector<int64_t>> FoldDims(
175+
absl::Span<Expression* const> dims) {
176+
std::vector<int64_t> result(dims.size());
177+
result.reserve(dims.size());
178+
int i = 0;
179+
for (Expression* dim : dims) {
180+
XLS_ASSIGN_OR_RETURN(result[i++], FoldEntireExpr(dim));
181+
}
182+
return result;
183+
}
184+
185+
absl::StatusOr<std::vector<Def*>> FoldTypesOfDefs(
186+
absl::Span<Def* const> defs) {
187+
std::vector<Def*> result(defs.size());
188+
int i = 0;
189+
for (Def* def : defs) {
190+
XLS_ASSIGN_OR_RETURN(DataType * folded_type,
191+
FoldConstants(def->data_type()));
192+
result[i++] = def->file()->Make<Def>(
193+
def->loc(), def->GetName(), def->data_kind(), folded_type,
194+
def->init().has_value() ? *def->init() : nullptr);
195+
}
196+
return result;
197+
}
198+
199+
absl::StatusOr<Expression*> FoldBinaryOp(Operator* op) {
200+
auto* binop = dynamic_cast<BinaryInfix*>(op);
201+
XLS_ASSIGN_OR_RETURN(Expression * folded_lhs, FoldConstants(binop->lhs()));
202+
XLS_ASSIGN_OR_RETURN(Expression * folded_rhs, FoldConstants(binop->rhs()));
203+
if (!folded_lhs->IsLiteral() || !folded_rhs->IsLiteral()) {
204+
return op->file()->Make<BinaryInfix>(op->loc(), folded_lhs, folded_rhs,
205+
op->kind());
206+
}
207+
Literal* lhs_literal = folded_lhs->AsLiteralOrDie();
208+
Literal* rhs_literal = folded_rhs->AsLiteralOrDie();
209+
XLS_ASSIGN_OR_RETURN(int64_t lhs_value,
210+
folded_lhs->AsLiteralOrDie()->ToInt64());
211+
XLS_ASSIGN_OR_RETURN(int64_t rhs_value,
212+
folded_rhs->AsLiteralOrDie()->ToInt64());
213+
bool signed_input = lhs_literal->is_declared_as_signed() &&
214+
rhs_literal->is_declared_as_signed();
215+
std::optional<bool> bool_result;
216+
std::optional<int64_t> int_result;
217+
switch (op->kind()) {
218+
case OperatorKind::kAdd:
219+
int_result = lhs_value + rhs_value;
220+
break;
221+
case OperatorKind::kSub:
222+
int_result = lhs_value - rhs_value;
223+
break;
224+
case OperatorKind::kMul:
225+
int_result = lhs_value * rhs_value;
226+
break;
227+
case OperatorKind::kDiv:
228+
int_result = lhs_value / rhs_value;
229+
break;
230+
case OperatorKind::kMod:
231+
int_result = lhs_value % rhs_value;
232+
break;
233+
case OperatorKind::kEq:
234+
bool_result = lhs_value == rhs_value;
235+
break;
236+
case OperatorKind::kNe:
237+
bool_result = lhs_value != rhs_value;
238+
break;
239+
case OperatorKind::kGe:
240+
bool_result = signed_input ? lhs_value >= rhs_value
241+
: static_cast<uint64_t>(lhs_value) >=
242+
static_cast<uint64_t>(rhs_value);
243+
break;
244+
case OperatorKind::kGt:
245+
bool_result = signed_input ? lhs_value > rhs_value
246+
: static_cast<uint64_t>(lhs_value) >
247+
static_cast<uint64_t>(rhs_value);
248+
break;
249+
case OperatorKind::kLe:
250+
bool_result = signed_input ? lhs_value <= rhs_value
251+
: static_cast<uint64_t>(lhs_value) <=
252+
static_cast<uint64_t>(rhs_value);
253+
break;
254+
case OperatorKind::kLt:
255+
bool_result = signed_input ? lhs_value < rhs_value
256+
: static_cast<uint64_t>(lhs_value) <
257+
static_cast<uint64_t>(rhs_value);
258+
break;
259+
default:
260+
break;
261+
}
262+
if (int_result.has_value()) {
263+
return MakeFoldedConstant(op, *int_result);
264+
}
265+
if (bool_result.has_value()) {
266+
return MakeFoldedConstant(op, static_cast<int64_t>(*bool_result));
267+
}
268+
return op->file()->Make<BinaryInfix>(op->loc(), folded_lhs, folded_rhs,
269+
op->kind());
270+
}
271+
272+
const absl::flat_hash_map<Expression*, DataType*>& type_map_;
273+
};
274+
275+
} // namespace
276+
277+
absl::StatusOr<DataType*> FoldVastConstants(
278+
DataType* data_type,
279+
const absl::flat_hash_map<Expression*, DataType*>& type_map) {
280+
auto context = std::make_unique<ConstantFoldingContext>(type_map);
281+
return context->FoldConstants(data_type);
282+
}
283+
284+
absl::StatusOr<Expression*> FoldVastConstants(
285+
Expression* expr,
286+
const absl::flat_hash_map<Expression*, DataType*>& type_map) {
287+
auto context = std::make_unique<ConstantFoldingContext>(type_map);
288+
return context->FoldConstants(expr);
289+
}
290+
291+
} // namespace verilog
292+
} // namespace xls

0 commit comments

Comments
 (0)