diff --git a/xls/codegen/BUILD b/xls/codegen/BUILD index c4a17708b9..7731af7147 100644 --- a/xls/codegen/BUILD +++ b/xls/codegen/BUILD @@ -150,9 +150,12 @@ cc_library( srcs = ["infer_vast_types.cc"], hdrs = ["infer_vast_types.h"], deps = [ + ":fold_vast_constants", ":vast", "//xls/common/status:status_macros", + "//xls/ir:source_location", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -183,6 +186,45 @@ cc_test( ], ) +cc_library( + name = "fold_vast_constants", + srcs = ["fold_vast_constants.cc"], + hdrs = ["fold_vast_constants.h"], + deps = [ + ":vast", + "//xls/common/status:status_macros", + "//xls/ir:bits", + "//xls/ir:format_preference", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/functional:any_invocable", + "@com_google_absl//absl/numeric:bits", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], +) + +cc_test( + name = "fold_vast_constants_test", + srcs = ["fold_vast_constants_test.cc"], + deps = [ + ":fold_vast_constants", + ":vast", + "//xls/common:xls_gunit_main", + "//xls/common/status:matchers", + "//xls/common/status:status_macros", + "//xls/ir:bits", + "//xls/ir:bits_ops", + "//xls/ir:format_preference", + "//xls/ir:number_parser", + "//xls/ir:source_location", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status:statusor", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "finite_state_machine", srcs = ["finite_state_machine.cc"], diff --git a/xls/codegen/fold_vast_constants.cc b/xls/codegen/fold_vast_constants.cc new file mode 100644 index 0000000000..ae7bf772ae --- /dev/null +++ b/xls/codegen/fold_vast_constants.cc @@ -0,0 +1,292 @@ +// Copyright 2024 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/codegen/fold_vast_constants.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/numeric/bits.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/types/span.h" +#include "xls/codegen/vast.h" +#include "xls/common/status/status_macros.h" +#include "xls/ir/bits.h" +#include "xls/ir/format_preference.h" + +namespace xls { +namespace verilog { +namespace { + +// Based upon the util version. +uint64_t Log2Ceiling(uint64_t n) { + int floor = (absl::bit_width(n) - 1); + return floor + (((n & (n - 1)) == 0) ? 0 : 1); +} + +// Helper class used internally by the exposed API to bind a type map and avoid +// passing it around. +class ConstantFoldingContext { + public: + explicit ConstantFoldingContext( + const absl::flat_hash_map& type_map) + : type_map_(type_map) {} + + absl::StatusOr FoldConstants(Expression* expr) { + if (expr->IsLiteral()) { + return expr; + } + if (auto* ref = dynamic_cast(expr); + ref && ref->parameter()->rhs()) { + return FoldConstants(ref->parameter()->rhs()); + } + if (auto* ref = dynamic_cast(expr); + ref && ref->member()->rhs()) { + return FoldConstants(ref->member()->rhs()); + } + if (auto* op = dynamic_cast(expr); op) { + return FoldBinaryOp(op); + } + if (auto* ternary = dynamic_cast(expr); ternary) { + absl::StatusOr test_value = FoldEntireExpr(ternary->test()); + absl::StatusOr consequent_value = + FoldEntireExpr(ternary->consequent()); + absl::StatusOr alternate_value = + FoldEntireExpr(ternary->alternate()); + if (test_value.ok() && consequent_value.ok() && alternate_value.ok()) { + return *test_value ? MakeFoldedConstant(ternary, *consequent_value) + : MakeFoldedConstant(ternary, *alternate_value); + } + } + if (auto* call = dynamic_cast(expr); call) { + if (call->name() == "clog2" && call->args().has_value() && + (*call->args()).size() == 1) { + absl::StatusOr arg_value = FoldEntireExpr((*call->args())[0]); + if (arg_value.ok()) { + return MakeFoldedConstant(call, Log2Ceiling(*arg_value)); + } + } + } + return expr; + } + + absl::StatusOr FoldEntireExpr(Expression* expr) { + XLS_ASSIGN_OR_RETURN(Expression * folded, FoldConstants(expr)); + if (folded->IsLiteral()) { + return folded->AsLiteralOrDie()->ToInt64(); + } + return absl::InvalidArgumentError( + absl::StrCat("Expression does not entirely fold to a constant: ", + expr->Emit(nullptr))); + } + + absl::StatusOr MakeFoldedConstant(Expression* original, + int64_t value) { + int64_t effective_size = 32; + bool is_signed = true; + const auto it = type_map_.find(original); + if (it != type_map_.end()) { + XLS_ASSIGN_OR_RETURN(effective_size, it->second->FlatBitCountAsInt64()); + is_signed = it->second->is_signed(); + } + return original->file()->Make( + original->loc(), + is_signed ? SBits(value, effective_size) + : UBits(static_cast(value), effective_size), + value < 0 ? FormatPreference::kHex : FormatPreference::kUnsignedDecimal, + /*effective_bit_count=*/effective_size, + /*emit_bit_count=*/effective_size != 32, + /*declared_as_signed=*/is_signed); + } + + absl::StatusOr FoldConstants(DataType* data_type) { + if (data_type->FlatBitCountAsInt64().ok()) { + return data_type; + } + if (auto* bit_vector_type = dynamic_cast(data_type); + bit_vector_type && !bit_vector_type->size_expr()->IsLiteral()) { + XLS_ASSIGN_OR_RETURN(int64_t folded_size, + FoldEntireExpr(bit_vector_type->size_expr())); + return data_type->file()->Make( + data_type->loc(), + data_type->file()->PlainLiteral(static_cast(folded_size), + data_type->loc()), + data_type->is_signed(), + /*size_expr_is_max=*/bit_vector_type->max().has_value()); + } + if (auto* array_type = dynamic_cast(data_type); + array_type) { + XLS_ASSIGN_OR_RETURN(std::vector folded_dims, + FoldDims(array_type->dims())); + XLS_ASSIGN_OR_RETURN(DataType * folded_element_type, + FoldConstants(array_type->element_type())); + return array_type->file()->template Make( + data_type->loc(), folded_element_type, folded_dims); + } + if (auto* array_type = dynamic_cast(data_type); + array_type) { + XLS_ASSIGN_OR_RETURN(std::vector folded_dims, + FoldDims(array_type->dims())); + XLS_ASSIGN_OR_RETURN(DataType * folded_element_type, + FoldConstants(array_type->element_type())); + return array_type->file()->template Make( + data_type->loc(), folded_element_type, folded_dims, + /*dims_are_max=*/array_type->dims_are_max()); + } + if (auto* enum_def = dynamic_cast(data_type); enum_def) { + XLS_ASSIGN_OR_RETURN(DataType * folded_base_type, + FoldConstants(enum_def->BaseType())); + return data_type->file()->Make(enum_def->loc(), enum_def->kind(), + folded_base_type, + enum_def->members()); + } + if (auto* struct_def = dynamic_cast(data_type); struct_def) { + XLS_ASSIGN_OR_RETURN(std::vector folded_member_defs, + FoldTypesOfDefs(struct_def->members())); + return struct_def->file()->Make(struct_def->loc(), + folded_member_defs); + } + if (auto* type_def = dynamic_cast(data_type); type_def) { + return FoldConstants(type_def->BaseType()); + } + return absl::InternalError(absl::StrCat("Could not constant-fold type: ", + data_type->Emit(nullptr))); + } + + private: + absl::StatusOr> FoldDims( + absl::Span dims) { + std::vector result(dims.size()); + result.reserve(dims.size()); + int i = 0; + for (Expression* dim : dims) { + XLS_ASSIGN_OR_RETURN(result[i++], FoldEntireExpr(dim)); + } + return result; + } + + absl::StatusOr> FoldTypesOfDefs( + absl::Span defs) { + std::vector result(defs.size()); + int i = 0; + for (Def* def : defs) { + XLS_ASSIGN_OR_RETURN(DataType * folded_type, + FoldConstants(def->data_type())); + result[i++] = def->file()->Make( + def->loc(), def->GetName(), def->data_kind(), folded_type, + def->init().has_value() ? *def->init() : nullptr); + } + return result; + } + + absl::StatusOr FoldBinaryOp(Operator* op) { + auto* binop = dynamic_cast(op); + XLS_ASSIGN_OR_RETURN(Expression * folded_lhs, FoldConstants(binop->lhs())); + XLS_ASSIGN_OR_RETURN(Expression * folded_rhs, FoldConstants(binop->rhs())); + if (!folded_lhs->IsLiteral() || !folded_rhs->IsLiteral()) { + return op->file()->Make(op->loc(), folded_lhs, folded_rhs, + op->kind()); + } + Literal* lhs_literal = folded_lhs->AsLiteralOrDie(); + Literal* rhs_literal = folded_rhs->AsLiteralOrDie(); + XLS_ASSIGN_OR_RETURN(int64_t lhs_value, + folded_lhs->AsLiteralOrDie()->ToInt64()); + XLS_ASSIGN_OR_RETURN(int64_t rhs_value, + folded_rhs->AsLiteralOrDie()->ToInt64()); + bool signed_input = lhs_literal->is_declared_as_signed() && + rhs_literal->is_declared_as_signed(); + std::optional bool_result; + std::optional int_result; + switch (op->kind()) { + case OperatorKind::kAdd: + int_result = lhs_value + rhs_value; + break; + case OperatorKind::kSub: + int_result = lhs_value - rhs_value; + break; + case OperatorKind::kMul: + int_result = lhs_value * rhs_value; + break; + case OperatorKind::kDiv: + int_result = lhs_value / rhs_value; + break; + case OperatorKind::kMod: + int_result = lhs_value % rhs_value; + break; + case OperatorKind::kEq: + bool_result = lhs_value == rhs_value; + break; + case OperatorKind::kNe: + bool_result = lhs_value != rhs_value; + break; + case OperatorKind::kGe: + bool_result = signed_input ? lhs_value >= rhs_value + : static_cast(lhs_value) >= + static_cast(rhs_value); + break; + case OperatorKind::kGt: + bool_result = signed_input ? lhs_value > rhs_value + : static_cast(lhs_value) > + static_cast(rhs_value); + break; + case OperatorKind::kLe: + bool_result = signed_input ? lhs_value <= rhs_value + : static_cast(lhs_value) <= + static_cast(rhs_value); + break; + case OperatorKind::kLt: + bool_result = signed_input ? lhs_value < rhs_value + : static_cast(lhs_value) < + static_cast(rhs_value); + break; + default: + break; + } + if (int_result.has_value()) { + return MakeFoldedConstant(op, *int_result); + } + if (bool_result.has_value()) { + return MakeFoldedConstant(op, static_cast(*bool_result)); + } + return op->file()->Make(op->loc(), folded_lhs, folded_rhs, + op->kind()); + } + + const absl::flat_hash_map& type_map_; +}; + +} // namespace + +absl::StatusOr FoldVastConstants( + DataType* data_type, + const absl::flat_hash_map& type_map) { + auto context = std::make_unique(type_map); + return context->FoldConstants(data_type); +} + +absl::StatusOr FoldVastConstants( + Expression* expr, + const absl::flat_hash_map& type_map) { + auto context = std::make_unique(type_map); + return context->FoldConstants(expr); +} + +} // namespace verilog +} // namespace xls diff --git a/xls/codegen/fold_vast_constants.h b/xls/codegen/fold_vast_constants.h new file mode 100644 index 0000000000..af3f2602f2 --- /dev/null +++ b/xls/codegen/fold_vast_constants.h @@ -0,0 +1,48 @@ +// Copyright 2024 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef XLS_CODEGEN_FOLD_VAST_CONSTANTS_H_ +#define XLS_CODEGEN_FOLD_VAST_CONSTANTS_H_ + +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "xls/codegen/vast.h" + +namespace xls { +namespace verilog { + +// Returns `expr` with the constants in it folded to the extent currently +// supported. The given `type_map` specifies the inferred types of +// sub-expressions (see `InferVastTypes()`); this should be provided if maximum +// bitwise accuracy to Verilog is desired. +// +// This function only fails due to unexpected conditions; inability to fold +// `expr` entirely into one `Literal`, or do anything at all, is not considered +// failure. The folding logic currently supported is focused on what is needed +// for typical data type folding. +absl::StatusOr FoldVastConstants( + Expression* expr, + const absl::flat_hash_map& type_map = {}); + +// Overload that folds the expressions within a `DataType` specification, to the +// point where the returned `DataType` can produce a flat bit count. If folding +// to that extent is not possible, this function will return an error. +absl::StatusOr FoldVastConstants( + DataType* data_type, + const absl::flat_hash_map& type_map = {}); + +} // namespace verilog +} // namespace xls + +#endif // XLS_CODEGEN_FOLD_VAST_CONSTANTS_H_ diff --git a/xls/codegen/fold_vast_constants_test.cc b/xls/codegen/fold_vast_constants_test.cc new file mode 100644 index 0000000000..5303c54b2b --- /dev/null +++ b/xls/codegen/fold_vast_constants_test.cc @@ -0,0 +1,334 @@ +// Copyright 2024 The XLS Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "xls/codegen/fold_vast_constants.h" + +#include +#include +#include +#include +#include + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "absl/log/check.h" +#include "absl/status/statusor.h" +#include "xls/codegen/vast.h" +#include "xls/common/status/matchers.h" +#include "xls/common/status/status_macros.h" +#include "xls/ir/bits.h" +#include "xls/ir/bits_ops.h" +#include "xls/ir/format_preference.h" +#include "xls/ir/number_parser.h" +#include "xls/ir/source_location.h" + +namespace xls { +namespace verilog { +namespace { + +using ::xls::status_testing::IsOkAndHolds; + +class FoldVastConstantsTest : public ::testing::Test { + public: + FoldVastConstantsTest() : file_(FileType::kSystemVerilog) { + module_ = file_.AddModule("test_module", SourceInfo()); + } + + Literal* BareLiteral(int32_t value) { + absl::StatusOr bits = ParseNumber(std::to_string(value)); + CHECK_OK(bits); + return file_.Make(SourceInfo(), + value < 0 ? bits_ops::SignExtend(*bits, 32) + : bits_ops::ZeroExtend(*bits, 32), + FormatPreference::kDefault, + /*declared_bit_count=*/32, + /*emit_bit_count=*/true, + /*declared_as_signed=*/true); + } + + std::string FoldConstantsToString(Expression* expr) { + absl::StatusOr folded = FoldVastConstants(expr); + XLS_EXPECT_OK(folded); + return (*folded)->Emit(nullptr); + } + + absl::StatusOr FoldConstantsAndGetBitCount(DataType* data_type) { + XLS_ASSIGN_OR_RETURN(DataType * folded, FoldVastConstants(data_type)); + return folded->FlatBitCountAsInt64(); + } + + VerilogFile file_; + Module* module_; +}; + +TEST_F(FoldVastConstantsTest, FoldLiteral) { + SourceInfo loc; + EXPECT_EQ(FoldConstantsToString(BareLiteral(2)), "2"); +} + +TEST_F(FoldVastConstantsTest, FoldLiteralArithmetic) { + SourceInfo loc; + EXPECT_EQ( + FoldConstantsToString(file_.Add(BareLiteral(2), BareLiteral(3), loc)), + "5"); + EXPECT_EQ( + FoldConstantsToString(file_.Sub(BareLiteral(5), BareLiteral(3), loc)), + "2"); + EXPECT_EQ( + FoldConstantsToString(file_.Sub(BareLiteral(2), BareLiteral(3), loc)), + "32'shffff_ffff"); + EXPECT_EQ( + FoldConstantsToString(file_.Mul( + BareLiteral(4), file_.Add(BareLiteral(2), BareLiteral(3), loc), loc)), + "20"); + EXPECT_EQ(FoldConstantsToString( + file_.Div(BareLiteral(24), + file_.Mul(BareLiteral(4), BareLiteral(3), loc), loc)), + "2"); + EXPECT_EQ(FoldConstantsToString( + file_.Mod(BareLiteral(24), + file_.Mul(BareLiteral(5), BareLiteral(2), loc), loc)), + "4"); +} + +TEST_F(FoldVastConstantsTest, FoldLiteralComparison) { + SourceInfo loc; + EXPECT_EQ( + FoldConstantsToString(file_.Equals(BareLiteral(1), BareLiteral(1), loc)), + "1"); + EXPECT_EQ( + FoldConstantsToString(file_.Equals(BareLiteral(1), BareLiteral(2), loc)), + "0"); + EXPECT_EQ(FoldConstantsToString( + file_.NotEquals(BareLiteral(1), BareLiteral(1), loc)), + "0"); + EXPECT_EQ(FoldConstantsToString( + file_.NotEquals(BareLiteral(1), BareLiteral(2), loc)), + "1"); + EXPECT_EQ(FoldConstantsToString( + file_.GreaterThan(BareLiteral(5), BareLiteral(3), loc)), + "1"); + EXPECT_EQ(FoldConstantsToString( + file_.GreaterThan(BareLiteral(3), BareLiteral(5), loc)), + "0"); + EXPECT_EQ(FoldConstantsToString( + file_.LessThan(BareLiteral(5), BareLiteral(3), loc)), + "0"); + EXPECT_EQ(FoldConstantsToString( + file_.LessThan(BareLiteral(3), BareLiteral(5), loc)), + "1"); + EXPECT_EQ(FoldConstantsToString(file_.GreaterThan( + BareLiteral(std::numeric_limits::max()), + BareLiteral(std::numeric_limits::min()), loc)), + "1"); + EXPECT_EQ(FoldConstantsToString(file_.GreaterThan( + BareLiteral(std::numeric_limits::max()), + BareLiteral(0), loc)), + "1"); + EXPECT_EQ(FoldConstantsToString(file_.LessThan( + BareLiteral(std::numeric_limits::min()), + BareLiteral(std::numeric_limits::min() + 1), loc)), + "1"); + EXPECT_EQ(FoldConstantsToString( + file_.LessThanEquals(BareLiteral(1), BareLiteral(1), loc)), + "1"); + EXPECT_EQ(FoldConstantsToString( + file_.LessThanEquals(BareLiteral(1), BareLiteral(0), loc)), + "0"); + EXPECT_EQ(FoldConstantsToString( + file_.LessThanEquals(BareLiteral(0), BareLiteral(1), loc)), + "1"); + EXPECT_EQ(FoldConstantsToString( + file_.GreaterThanEquals(BareLiteral(1), BareLiteral(1), loc)), + "1"); + EXPECT_EQ(FoldConstantsToString( + file_.GreaterThanEquals(BareLiteral(1), BareLiteral(0), loc)), + "1"); + EXPECT_EQ(FoldConstantsToString( + file_.GreaterThanEquals(BareLiteral(0), BareLiteral(1), loc)), + "0"); + // Weird-sized values not created with the convenience constructor. + Bits three_ones = Bits::AllOnes(3); + Bits two_of_three_ones(3); + two_of_three_ones.SetRange(0, 2, true); + // Try as unsigned. + EXPECT_EQ( + FoldConstantsToString(file_.GreaterThan( + file_.Make(loc, three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/false), + file_.Make(loc, two_of_three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/false), + loc)), + "1"); + EXPECT_EQ( + FoldConstantsToString(file_.GreaterThanEquals( + file_.Make(loc, three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/false), + file_.Make(loc, two_of_three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/false), + loc)), + "1"); + EXPECT_EQ( + FoldConstantsToString(file_.LessThan( + file_.Make(loc, three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/false), + file_.Make(loc, two_of_three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/false), + loc)), + "0"); + EXPECT_EQ( + FoldConstantsToString(file_.LessThanEquals( + file_.Make(loc, three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/false), + file_.Make(loc, two_of_three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/false), + loc)), + "0"); + // Try as signed. + EXPECT_EQ( + FoldConstantsToString(file_.GreaterThan( + file_.Make(loc, three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/true), + file_.Make(loc, two_of_three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/true), + loc)), + "0"); + EXPECT_EQ( + FoldConstantsToString(file_.GreaterThanEquals( + file_.Make(loc, three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/true), + file_.Make(loc, two_of_three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/true), + loc)), + "0"); + EXPECT_EQ( + FoldConstantsToString(file_.LessThan( + file_.Make(loc, three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/true), + file_.Make(loc, two_of_three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/true), + loc)), + "1"); + EXPECT_EQ( + FoldConstantsToString(file_.LessThanEquals( + file_.Make(loc, three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/true), + file_.Make(loc, two_of_three_ones, FormatPreference::kHex, 3, + /*emit_bit_count=*/true, + /*declared_as_signed=*/true), + loc)), + "1"); +} + +TEST_F(FoldVastConstantsTest, FoldLiteralAndParameter) { + SourceInfo loc; + auto* foo = module_->AddParameter("foo", BareLiteral(3), loc); + EXPECT_EQ(FoldConstantsToString(file_.Mul( + BareLiteral(4), file_.Add(BareLiteral(2), foo, loc), loc)), + "20"); +} + +TEST_F(FoldVastConstantsTest, FoldLiteralAndEnumValue) { + SourceInfo loc; + Enum* enum_def = + file_.Make(loc, DataKind::kLogic, file_.BitVectorType(16, loc)); + auto* foo = enum_def->AddMember("foo", BareLiteral(3), loc); + EXPECT_EQ(FoldConstantsToString(file_.Mul( + BareLiteral(4), file_.Add(BareLiteral(2), foo, loc), loc)), + "20"); +} + +TEST_F(FoldVastConstantsTest, FoldTernary) { + SourceInfo loc; + auto* foo = module_->AddParameter("foo", BareLiteral(5), loc); + EXPECT_EQ(FoldConstantsToString( + file_.Ternary(file_.GreaterThan(foo, BareLiteral(7), loc), + file_.Add(foo, BareLiteral(1), loc), + file_.Mul(foo, BareLiteral(2), loc), loc)), + "10"); +} + +TEST_F(FoldVastConstantsTest, FoldComplexBitVectorSpec) { + SourceInfo loc; + auto* foo = module_->AddParameter("foo", BareLiteral(3), loc); + EXPECT_THAT( + FoldConstantsAndGetBitCount(file_.Make( + loc, + file_.Mul(BareLiteral(4), file_.Add(BareLiteral(2), foo, loc), loc), + /*is_signed=*/false, /*size_expr_is_max=*/false)), + IsOkAndHolds(20)); +} + +TEST_F(FoldVastConstantsTest, FoldEnumBaseType) { + SourceInfo loc; + auto* foo = module_->AddParameter("foo", BareLiteral(3), loc); + EXPECT_THAT(FoldConstantsAndGetBitCount(file_.Make( + loc, DataKind::kLogic, + file_.Make( + loc, + file_.Mul(BareLiteral(4), + file_.Add(BareLiteral(2), foo, loc), loc), + /*is_signed=*/false, /*size_expr_is_max=*/false))), + IsOkAndHolds(20)); +} + +TEST_F(FoldVastConstantsTest, FoldComplexPackedArraySpec) { + SourceInfo loc; + auto* foo = module_->AddParameter("foo", BareLiteral(3), loc); + EXPECT_THAT(FoldConstantsAndGetBitCount(file_.Make( + loc, + file_.Make( + loc, + file_.Mul(BareLiteral(4), + file_.Add(BareLiteral(2), foo, loc), loc), + /*is_signed=*/false, /*size_expr_is_max=*/false), + std::vector{file_.Sub(foo, BareLiteral(1), loc)}, + /*dims_are_max=*/true)), + IsOkAndHolds(40)); +} + +TEST_F(FoldVastConstantsTest, FoldStructDef) { + SourceInfo loc; + auto* foo = module_->AddParameter("foo", BareLiteral(3), loc); + auto* bit_vector_type = file_.Make( + loc, file_.Mul(BareLiteral(4), file_.Add(BareLiteral(2), foo, loc), loc), + /*is_signed=*/false, /*size_expr_is_max=*/false); + EXPECT_THAT( + FoldConstantsAndGetBitCount(file_.Make( + loc, + std::vector{file_.Make(loc, "member1", DataKind::kLogic, + bit_vector_type), + file_.Make(loc, "member2", DataKind::kLogic, + bit_vector_type)})), + IsOkAndHolds(40)); +} + +} // namespace +} // namespace verilog +} // namespace xls diff --git a/xls/codegen/infer_vast_types.cc b/xls/codegen/infer_vast_types.cc index 4c2c29b8bd..04ff6fffc7 100644 --- a/xls/codegen/infer_vast_types.cc +++ b/xls/codegen/infer_vast_types.cc @@ -23,13 +23,16 @@ #include #include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/types/span.h" +#include "xls/codegen/fold_vast_constants.h" #include "xls/codegen/vast.h" #include "xls/common/status/status_macros.h" +#include "xls/ir/source_location.h" namespace xls { namespace verilog { @@ -43,9 +46,10 @@ absl::flat_hash_map BuildSystemFunctionReturnTypes( VerilogFile* file) { DataType* int_type = file->IntegerType(SourceInfo()); DataType* scalar_type = file->ScalarType(SourceInfo()); - return {{"clog2", int_type}, {"countbits", int_type}, - {"countones", int_type}, {"onehot", scalar_type}, - {"onehot0", scalar_type}, {"isunknown", scalar_type}}; + return {{"bits", int_type}, {"clog2", int_type}, + {"countbits", int_type}, {"countones", int_type}, + {"onehot", scalar_type}, {"onehot0", scalar_type}, + {"isunknown", scalar_type}}; } // A utility that helps build a map of inferred types. @@ -107,7 +111,7 @@ class TypeInferenceVisitor { // promote the internals of RHS, while reflecting the fact that the top node // must be downcast to fit the variable. if (external_context_type.has_value()) { - (*types_)[expr] = *external_context_type; + (*types_)[expr] = MaybeFoldConstants(*external_context_type); } return absl::OkStatus(); } @@ -131,23 +135,68 @@ class TypeInferenceVisitor { return absl::OkStatus(); } - absl::Status TraverseTypedef(Typedef* type_def) { - if (auto* enum_def = dynamic_cast(type_def->data_type()); enum_def) { + absl::Status TraverseDataType(DataType* data_type) { + if (auto* bit_vector_type = dynamic_cast(data_type); + bit_vector_type && !bit_vector_type->size_expr()->IsLiteral()) { + return TraverseExpression(bit_vector_type->size_expr()); + } + if (auto* array_type = dynamic_cast(data_type); + array_type) { + XLS_RETURN_IF_ERROR(TraverseDataType(array_type->element_type())); + for (Expression* dim : array_type->dims()) { + if (!dim->IsLiteral()) { + XLS_RETURN_IF_ERROR(TraverseExpression(dim)); + } + } + } + if (auto* enum_def = dynamic_cast(data_type); enum_def) { return TraverseEnum(enum_def); } + if (auto* struct_def = dynamic_cast(data_type); struct_def) { + return TraverseStruct(struct_def); + } + return absl::OkStatus(); + } + + absl::Status TraverseStruct(Struct* struct_def) { + for (Def* def : struct_def->members()) { + XLS_RETURN_IF_ERROR(TraverseDataType(def->data_type())); + } return absl::OkStatus(); } + absl::Status TraverseTypedef(Typedef* type_def) { + return TraverseDataType(type_def->data_type()); + } + absl::Status TraverseParameter(Parameter* parameter) { std::optional data_type = std::nullopt; if (parameter->def()) { data_type = parameter->def()->data_type(); + XLS_RETURN_IF_ERROR(TraverseDataType(*data_type)); } - return TraverseExpression(parameter->rhs(), - /*external_context_type=*/data_type); + XLS_RETURN_IF_ERROR( + TraverseExpression(parameter->rhs(), + /*external_context_type=*/data_type)); + if (!parameter->def()) { + // For parameters of the form `parameter foo = some_expr;`, without a type + // on the LHS, the RHS decides the type, and there has to be an RHS. We + // store this type in the separate `auto_parameter_types_` map, because + // the normal type map can only have exprs as keys (and clients of + // inference only care about expr types). + const auto it = types_->find(parameter->rhs()); + if (it == types_->end()) { + return absl::InvalidArgumentError( + absl::StrCat("No type could be inferred for untyped parameter: ", + parameter->Emit(nullptr))); + } + auto_parameter_types_.emplace(parameter, it->second); + } + return absl::OkStatus(); } absl::Status TraverseEnum(Enum* enum_def) { + XLS_RETURN_IF_ERROR(TraverseDataType(enum_def->BaseType())); for (EnumMember* member : enum_def->members()) { if (member->rhs() != nullptr) { XLS_RETURN_IF_ERROR(TraverseExpression( @@ -273,12 +322,17 @@ class TypeInferenceVisitor { } } if (auto* ref = dynamic_cast(expr); ref) { - return ref->parameter()->def() == nullptr - ? ref->file()->IntegerType(ref->loc()) - : ref->parameter()->def()->data_type(); + if (ref->parameter()->def() != nullptr) { + return ref->parameter()->def()->data_type(); + } + const auto it = auto_parameter_types_.find(ref->parameter()); + if (it != auto_parameter_types_.end()) { + return it->second; + } + return ref->file()->IntegerType(ref->loc()); } if (auto* ref = dynamic_cast(expr); ref) { - return ref->enum_def()->BaseType(); + return ref->enum_def(); } if (auto* ref = dynamic_cast(expr); ref) { return ref->def()->data_type(); @@ -303,7 +357,7 @@ class TypeInferenceVisitor { arg->Emit(nullptr))); } XLS_ASSIGN_OR_RETURN(int64_t arg_bit_count, - it->second->FlatBitCountAsInt64()); + EvaluateBitCount(it->second)); bit_count += arg_bit_count; } return expr->file()->BitVectorType(bit_count, expr->loc(), @@ -392,25 +446,40 @@ class TypeInferenceVisitor { return; } if (auto* ternary = dynamic_cast(expr); ternary) { - types_->emplace(ternary->consequent(), data_type); - types_->emplace(ternary->alternate(), data_type); + ApplyInferredTypeRecursively(data_type, ternary->consequent()); + ApplyInferredTypeRecursively(data_type, ternary->alternate()); } } absl::StatusOr LargestType(DataType* a, DataType* b, bool reconcile_signedness = true) { - XLS_ASSIGN_OR_RETURN(int64_t a_bit_count, a->FlatBitCountAsInt64()); - XLS_ASSIGN_OR_RETURN(int64_t b_bit_count, b->FlatBitCountAsInt64()); + DataType* folded_a = MaybeFoldConstants(a); + DataType* folded_b = MaybeFoldConstants(b); + absl::StatusOr maybe_a_bit_count = folded_a->FlatBitCountAsInt64(); + absl::StatusOr maybe_b_bit_count = folded_b->FlatBitCountAsInt64(); + // In cases like `parameter logic[$clog2(32767):0] = ...`, we don't have the + // ability to infer one of the types, and it's unlikely to matter. + if (!maybe_b_bit_count.ok()) { + return folded_a; + } + if (!maybe_a_bit_count.ok()) { + return folded_b; + } + int64_t a_bit_count = *maybe_a_bit_count; + int64_t b_bit_count = *maybe_b_bit_count; bool b_int = dynamic_cast(b) != nullptr; int64_t result_bit_count; DataType* result; - // Prefer the larger type, but if they are equivalent, prefer the integer - // type, if any. - if (a_bit_count > b_bit_count || (a_bit_count == b_bit_count && !b_int)) { - result = a; + // Prefer the larger type, but if they are equivalent: + // 1. Prefer the integer type, if any, as it's more precise as to intent. + // 2. Prefer the RHS in a case of sign mismatch without reconciliation. + if (a_bit_count > b_bit_count || + (a_bit_count == b_bit_count && a->is_signed() == b->is_signed() && + !b_int)) { + result = folded_a; result_bit_count = a_bit_count; } else { - result = b; + result = folded_b; result_bit_count = b_bit_count; } // Don't propagate user-defined types where the user didn't use them. @@ -430,11 +499,39 @@ class TypeInferenceVisitor { return data_type->file()->Make(data_type->loc(), /*is_signed=*/false); } - XLS_ASSIGN_OR_RETURN(int64_t bit_count, data_type->FlatBitCountAsInt64()); + XLS_ASSIGN_OR_RETURN(int64_t bit_count, EvaluateBitCount(data_type)); return data_type->file()->BitVectorType(bit_count, data_type->loc()); } + DataType* MaybeFoldConstants(DataType* data_type) { + if (!data_type->FlatBitCountAsInt64().ok()) { + absl::StatusOr folded_type = + FoldVastConstants(data_type, *types_); + if (folded_type.ok()) { + return *folded_type; + } + VLOG(2) << "Could not fold: " << data_type->Emit(nullptr) + << ", status: " << folded_type.status(); + } + return data_type; + } + + absl::StatusOr EvaluateBitCount(DataType* data_type) { + absl::StatusOr direct_answer = data_type->FlatBitCountAsInt64(); + if (direct_answer.ok()) { + return direct_answer; + } + absl::StatusOr folded_type = + FoldVastConstants(data_type, *types_); + if (folded_type.ok()) { + return (*folded_type)->FlatBitCountAsInt64(); + } + return absl::InvalidArgumentError(absl::StrCat( + "Could not evaluate bit count for type: ", data_type->Emit(nullptr))); + } + absl::flat_hash_map* types_; + absl::flat_hash_map auto_parameter_types_; const absl::flat_hash_map system_function_return_types_; }; @@ -450,6 +547,20 @@ absl::StatusOr> InferVastTypes( return types; } +absl::StatusOr> InferVastTypes( + absl::Span corpus) { + absl::flat_hash_map types; + std::unique_ptr visitor; + for (VerilogFile* file : corpus) { + if (visitor == nullptr) { + visitor = std::make_unique( + &types, BuildSystemFunctionReturnTypes(file)); + } + XLS_RETURN_IF_ERROR(visitor->TraverseFile(file)); + } + return types; +} + absl::StatusOr> InferVastTypes( Expression* expr) { absl::flat_hash_map types; diff --git a/xls/codegen/infer_vast_types.h b/xls/codegen/infer_vast_types.h index d4284721dd..5694895372 100644 --- a/xls/codegen/infer_vast_types.h +++ b/xls/codegen/infer_vast_types.h @@ -17,6 +17,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/status/statusor.h" +#include "absl/types/span.h" #include "xls/codegen/vast.h" namespace xls { @@ -30,6 +31,12 @@ namespace verilog { absl::StatusOr> InferVastTypes( VerilogFile* file); +// Builds a map of the inferred types of all expressions in the given `corpus`. +// If a `VerilogFile` may have cross-file references, it is best to use this +// function to build a unified map of all types in the corpus. +absl::StatusOr> InferVastTypes( + absl::Span corpus); + // Builds a map of the inferred types for just `expr` and its descendants. This // should not be used to analyze a whole file in steps. absl::StatusOr> InferVastTypes( diff --git a/xls/codegen/infer_vast_types_test.cc b/xls/codegen/infer_vast_types_test.cc index 0c2a23efed..f8400dd4b8 100644 --- a/xls/codegen/infer_vast_types_test.cc +++ b/xls/codegen/infer_vast_types_test.cc @@ -240,6 +240,8 @@ TEST_F(InferVastTypesTest, TernaryExampleFromSpec) { EXPECT_EQ( InferTypesToString(file_.Ternary(c, file_.BitwiseAnd(a, b, loc), d, loc)), R"( +a : [4:0] +b : [4:0] c : [3:0] d : [4:0] a & b : [4:0] @@ -247,6 +249,40 @@ c ? a & b : d : [4:0] )"); } +TEST_F(InferVastTypesTest, ComplexTernary) { + // (a > b) ? a : ((b < c) ? b : c) + // where a is 3 bits, b is 4 bits, and c is 5 bits. The point here is mainly + // to make sure all the refs to each variable get their own inferred type. + SourceInfo loc; + auto* a = module_->AddParameter( + file_.Make(loc, "a", DataKind::kLogic, file_.BitVectorType(3, loc)), + UnsignedZero(3), loc); + auto* b = module_->AddParameter( + file_.Make(loc, "b", DataKind::kLogic, file_.BitVectorType(4, loc)), + UnsignedZero(4), loc); + auto* c = module_->AddParameter( + file_.Make(loc, "c", DataKind::kLogic, file_.BitVectorType(5, loc)), + UnsignedZero(5), loc); + EXPECT_EQ(InferTypesToString(file_.Ternary( + file_.GreaterThan(a, b, loc), a->Duplicate(), + file_.Ternary(file_.LessThan(b->Duplicate(), c, loc), + b->Duplicate(), c->Duplicate(), loc), + loc)), + R"( +a : [3:0] +a : [4:0] +b : [3:0] +b : [4:0] +b : [4:0] +c : [4:0] +c : [4:0] +a > b : logic +b < c : logic +a > b ? a : (b < c ? b : c) : [4:0] +b < c ? b : c : [4:0] +)"); +} + TEST_F(InferVastTypesTest, SimpleMultiplicationFromSpec) { SourceInfo loc; auto* a = module_->AddParameter( @@ -309,6 +345,28 @@ a ** b : [15:0] )"); } +TEST_F(InferVastTypesTest, MultiFile) { + SourceInfo loc; + VerilogFile file2(FileType::kSystemVerilog); + Module* module2 = file2.AddModule("module2", loc); + auto* a = module_->AddParameter( + file_.Make(loc, "a", DataKind::kLogic, file_.BitVectorType(4, loc)), + UnsignedZero(4), loc); + module2->AddParameter( + file2.Make(loc, "b", DataKind::kLogic, file2.BitVectorType(6, loc)), + file2.Add(file2.PlainLiteral(50, loc), a, loc), loc); + std::vector files = {&file_, &file2}; + auto types = InferVastTypes(files); + XLS_ASSERT_OK(types); + EXPECT_EQ(TypesToString(*types), + R"( +0 : [3:0] +50 : unsigned +a : unsigned +50 + a : [5:0] +)"); +} + TEST_F(InferVastTypesTest, BigConstants) { SourceInfo loc; // parameter int unsigned GiB = 1024 * 1024 * 1024; @@ -344,6 +402,36 @@ GiB : [63:0] )"); } +TEST_F(InferVastTypesTest, NonLiteralArrayDim) { + SourceInfo loc; + // parameter foo = 24; + // parameter logic[foo - 1:0] bar = 55; + // parameter logic[63:0] baz = bar + 1; + auto* foo = module_->AddParameter("foo", BareLiteral(24), loc); + auto* bar = module_->AddParameter( + file_.Make(loc, "bar", DataKind::kLogic, + file_.Make( + loc, file_.Sub(foo, BareLiteral(1), loc), + /*is_signed=*/false, /*size_expr_is_max=*/true)), + BareLiteral(55), loc); + module_->AddParameter( + file_.Make( + loc, "baz", DataKind::kLogic, + file_.Make(loc, BareLiteral(63), /*is_signed=*/false, + /*size_expr_is_max=*/true)), + file_.Add(bar, BareLiteral(1), loc), loc); + EXPECT_EQ(InferTypesToString(), R"( +1 : [63:0] +1 : integer +24 : integer +55 : [23:0] +bar : [63:0] +foo : integer +bar + 1 : [63:0] +foo - 1 : integer +)"); +} + TEST_F(InferVastTypesTest, FunctionCall) { // function automatic logic[15:0] fn( // logic[7:0] a, @@ -424,6 +512,72 @@ foo + 1 : unsigned )"); } +TEST_F(InferVastTypesTest, PromotionToFoldedTypedef) { + // parameter size = 50; + // typedef logic[size - 1:0] foo_t; + // parameter mb = 1024 * 1024; + // parameter foo_t bar = 50 * mb; + SourceInfo loc; + auto* size = module_->AddParameter("size", BareLiteral(50), loc); + Typedef* type_def = file_.Make( + loc, + file_.Make(loc, "foo_t", DataKind::kLogic, + file_.Make( + loc, file_.Sub(size, BareLiteral(1), loc), + /*is_signed=*/false, /*size_expr_is_max=*/true))); + module_->AddModuleMember(type_def); + auto* mb = module_->AddParameter( + "mb", file_.Mul(BareLiteral(1024), BareLiteral(1024), loc), loc); + module_->AddParameter(file_.Make(loc, "bar", DataKind::kUser, + file_.Make(loc, type_def)), + file_.Mul(BareLiteral(50), mb, loc), loc); + EXPECT_EQ(InferTypesToString(), + R"( +1 : integer +50 : [49:0] +50 : integer +1024 : integer +1024 : integer +mb : [49:0] +size : integer +50 * mb : [49:0] +size - 1 : integer +1024 * 1024 : integer +)"); +} + +TEST_F(InferVastTypesTest, ConcatWithFoldedTypedef) { + // parameter size = 50; + // typedef logic[size - 1:0] foo_t; + // parameter foo_t foo = 3; + // parameter integer bar = 4; + // {foo, bar} + SourceInfo loc; + auto* size = module_->AddParameter("size", BareLiteral(50), loc); + Typedef* type_def = file_.Make( + loc, + file_.Make(loc, "foo_t", DataKind::kLogic, + file_.Make( + loc, file_.Sub(size, BareLiteral(1), loc), + /*is_signed=*/false, /*size_expr_is_max=*/true))); + module_->AddModuleMember(type_def); + auto* foo = module_->AddParameter( + file_.Make(loc, "foo", DataKind::kUser, + file_.Make(loc, type_def)), + BareLiteral(3), loc); + auto* bar = + module_->AddParameter(file_.Make(loc, "bar", DataKind::kInteger, + file_.Make(loc)), + BareLiteral(4), loc); + EXPECT_EQ( + InferTypesToString(file_.Concat(std::vector{foo, bar}, loc)), + R"( +bar : integer +foo : foo_t +{foo, bar} : [81:0] +)"); +} + TEST_F(InferVastTypesTest, PackedStruct) { // typedef struct packed { // logic foo; @@ -519,6 +673,41 @@ a + 1 : [15:0] )"); } +TEST_F(InferVastTypesTest, TypedefReturnType) { + // parameter width = 24; + // typedef logic[width - 1:0] word_t; + // function automatic word_t fn( + // word_t a); + // return a + 3'0; + // endfunction + // Getting this right requires constant folding the definition of `word_t` and + // promoting the 3-bit value. + SourceInfo loc; + auto* width = module_->AddParameter("width", BareLiteral(24), loc); + Typedef* word_t = module_->AddTypedef( + file_.Make(loc, "word_t", DataKind::kLogic, + file_.Make( + loc, file_.Sub(width, BareLiteral(1), loc), + /*is_signed=*/false, /*size_expr_is_max=*/true)), + loc); + TypedefType* word_t_type = file_.Make(loc, word_t); + VerilogFunction* fn = file_.Make(loc, "fn", word_t_type); + LogicRef* a = fn->AddArgument( + file_.Make(loc, "a", DataKind::kUser, word_t_type), loc); + fn->AddStatement(loc, file_.Add(a, UnsignedZero(3), loc)); + module_->top()->AddModuleMember(fn); + EXPECT_EQ(InferTypesToString(), + R"( +0 : [23:0] +1 : integer +24 : integer +a : [23:0] +a + 0 : [23:0] +width : integer +width - 1 : integer +)"); +} + TEST_F(InferVastTypesTest, ContextDependentUnary) { // b = ~(a + 0) where all parameters are logic [15:0]. SourceInfo loc; @@ -606,6 +795,19 @@ b + c : [31:0] )"); } +TEST_F(InferVastTypesTest, UseOfUntypedParameter) { + SourceInfo loc; + auto* foo = module_->AddParameter("foo", UnsignedZero(4), loc); + module_->AddParameter("bar", file_.Add(UnsignedZero(2), foo, loc), loc); + file_.PlainLiteral(3, SourceInfo()); + EXPECT_EQ(InferTypesToString(), R"( +0 : [3:0] +0 : [3:0] +foo : [3:0] +0 + foo : [3:0] +)"); +} + } // namespace } // namespace verilog } // namespace xls diff --git a/xls/codegen/vast.cc b/xls/codegen/vast.cc index b1506228b2..e27c095839 100644 --- a/xls/codegen/vast.cc +++ b/xls/codegen/vast.cc @@ -622,6 +622,10 @@ ParameterRef* VerilogPackageSection::AddParameter(Def* def, Expression* rhs, return file()->Make(loc, param); } +ParameterRef* ParameterRef::Duplicate() const { + return file()->Make(loc(), parameter()); +} + Literal* Expression::AsLiteralOrDie() { CHECK(IsLiteral()); return static_cast(this); @@ -642,6 +646,10 @@ LogicRef* Expression::AsLogicRefOrDie() { return static_cast(this); } +LogicRef* LogicRef::Duplicate() const { + return file()->Make(loc(), def()); +} + std::string XSentinel::Emit(LineInfo* line_info) const { LineInfoStart(line_info, this); LineInfoEnd(line_info, this); @@ -694,11 +702,13 @@ BitVectorType::BitVectorType(int64_t width, bool is_signed, VerilogFile* file, is_signed_(is_signed) {} absl::StatusOr BitVectorType::WidthAsInt64() const { - if (!size_expr_->IsLiteral() || size_expr_is_max_) { + if (!size_expr_->IsLiteral()) { return absl::FailedPreconditionError("Width is not a literal: " + size_expr_->Emit(nullptr)); } - return size_expr_->AsLiteralOrDie()->bits().ToUint64(); + XLS_ASSIGN_OR_RETURN(int64_t size, + size_expr_->AsLiteralOrDie()->bits().ToUint64()); + return size + (size_expr_is_max_ ? 1 : 0); } absl::StatusOr BitVectorType::FlatBitCountAsInt64() const { @@ -750,6 +760,13 @@ PackedArrayType::PackedArrayType(int64_t width, is_signed), packed_dims, /*dims_are_max=*/false, file, loc) {} +PackedArrayType::PackedArrayType(DataType* element_type, + absl::Span packed_dims, + bool dims_are_max, VerilogFile* file, + const SourceInfo& loc) + : ArrayTypeBase(element_type, packed_dims, /*dims_are_max=*/dims_are_max, + file, loc) {} + absl::StatusOr PackedArrayType::FlatBitCountAsInt64() const { XLS_ASSIGN_OR_RETURN(int64_t bit_count, WidthAsInt64()); for (Expression* dim : dims()) { @@ -847,7 +864,7 @@ IntegerDef::IntegerDef(std::string_view name, VerilogFile* file, IntegerDef::IntegerDef(std::string_view name, DataType* data_type, Expression* init, VerilogFile* file, const SourceInfo& loc) - : Def(name, DataKind::kInteger, file->IntegerType(loc), init, file, loc) {} + : Def(name, DataKind::kInteger, data_type, init, file, loc) {} namespace { @@ -1254,6 +1271,10 @@ EnumMemberRef* Enum::AddMember(std::string_view name, Expression* rhs, return file()->Make(loc, this, members_.back()); } +EnumMemberRef* EnumMemberRef::Duplicate() const { + return file()->Make(loc(), enum_def(), member()); +} + std::string EnumMember::Emit(LineInfo* line_info) const { LineInfoStart(line_info, this); std::string result = absl::StrFormat("%s = %s", name_, rhs_->Emit(line_info)); diff --git a/xls/codegen/vast.h b/xls/codegen/vast.h index caa9cca6af..cec5c66f35 100644 --- a/xls/codegen/vast.h +++ b/xls/codegen/vast.h @@ -41,6 +41,7 @@ #include "absl/types/span.h" #include "xls/codegen/module_signature.pb.h" #include "xls/ir/bits.h" +#include "xls/ir/bits_ops.h" #include "xls/ir/format_preference.h" #include "xls/ir/source_location.h" @@ -296,6 +297,7 @@ class BitVectorType : public DataType { bool IsScalar() const override { return false; } absl::StatusOr WidthAsInt64() const override; absl::StatusOr FlatBitCountAsInt64() const override; + std::optional width() const override { if (size_expr_is_max_) { return std::nullopt; @@ -308,6 +310,11 @@ class BitVectorType : public DataType { } return std::nullopt; } + + // Returns the expression for either the width or max; whichever was supplied + // at construction time. + Expression* size_expr() const { return size_expr_; } + bool is_signed() const override { return is_signed_; } std::string Emit(LineInfo* line_info) const override; @@ -387,6 +394,9 @@ class PackedArrayType : public ArrayTypeBase { VerilogFile* file, const SourceInfo& loc) : ArrayTypeBase(element_type, packed_dims, dims_are_max, file, loc) {} + PackedArrayType(DataType* element_type, absl::Span packed_dims, + bool dims_are_max, VerilogFile* file, const SourceInfo& loc); + absl::StatusOr FlatBitCountAsInt64() const override; std::string Emit(LineInfo* line_info) const override; @@ -1125,6 +1135,11 @@ class EnumMemberRef : public Expression { Enum* enum_def() const { return enum_def_; } EnumMember* member() const { return member_; } + // Duplicates this reference for use in another place. If performing type + // inference on the VAST tree, the same exact ref object should not be used in + // multiple places. + EnumMemberRef* Duplicate() const; + private: Enum* enum_def_; EnumMember* member_; @@ -1137,8 +1152,9 @@ class Enum : public UserDefinedAliasType { const SourceInfo& loc) : UserDefinedAliasType(data_type, file, loc), kind_(kind) {} - Enum(DataKind kind, DataType* data_type, absl::Span members, - VerilogFile* file, const SourceInfo& loc) + Enum(DataKind kind, DataType* data_type, + absl::Span members, VerilogFile* file, + const SourceInfo& loc) : UserDefinedAliasType(data_type, file, loc), kind_(kind), members_(members.begin(), members.end()) {} @@ -1245,6 +1261,11 @@ class ParameterRef : public Expression { Parameter* parameter() const { return parameter_; } + // Duplicates this reference for use in another place. If performing type + // inference on the VAST tree, the same exact ref object should not be used in + // multiple places. + ParameterRef* Duplicate() const; + private: Parameter* parameter_; }; @@ -1275,6 +1296,11 @@ class LogicRef : public IndexableExpression { // Returns the name of the underlying Def this object refers to. std::string GetName() const { return def()->GetName(); } + // Duplicates this reference for use in another place. If performing type + // inference on the VAST tree, the same exact ref object should not be used in + // multiple places. + LogicRef* Duplicate() const; + private: // Logic signal definition. Def* def_; @@ -1512,6 +1538,13 @@ class Literal : public Expression { const Bits& bits() const { return bits_; } + absl::StatusOr ToInt64() const { + if (effective_bit_count() != bits_.bit_count()) { + return bits_ops::ZeroExtend(bits_, effective_bit_count()).ToInt64(); + } + return bits_.ToInt64(); + } + bool IsLiteral() const override { return true; } bool IsLiteralWithValue(int64_t target) const override; diff --git a/xls/dslx/BUILD b/xls/dslx/BUILD index 7b5902d606..5895a54e31 100644 --- a/xls/dslx/BUILD +++ b/xls/dslx/BUILD @@ -193,9 +193,9 @@ cc_test( cc_library( name = "trait_visitor", - srcs = ["trait_visitor.cc"], hdrs = ["trait_visitor.h"], deps = [ + "//xls/common/status:status_macros", "//xls/dslx/frontend:ast", "@com_google_absl//absl/status", ], diff --git a/xls/dslx/trait_visitor.h b/xls/dslx/trait_visitor.h index 32d848b8cf..0a8f65020a 100644 --- a/xls/dslx/trait_visitor.h +++ b/xls/dslx/trait_visitor.h @@ -17,6 +17,7 @@ #include #include "absl/status/status.h" +#include "xls/common/status/status_macros.h" #include "xls/dslx/frontend/ast.h" namespace xls::dslx { @@ -25,7 +26,28 @@ namespace xls::dslx { // Lazily populated as information is needed. class TraitVisitor : public ExprVisitorWithDefault { public: - absl::Status HandleNameRef(const NameRef* expr) override; + template + absl::Status Handle(const T* expr) { + for (AstNode* child : expr->GetChildren(true)) { + if (Expr* child_expr = dynamic_cast(child); child_expr) { + XLS_RETURN_IF_ERROR(child_expr->AcceptExpr(this)); + } + } + return absl::OkStatus(); + } + + template <> + absl::Status Handle(const NameRef* expr) { + name_refs_.push_back(expr); + return absl::OkStatus(); + } + +#define DEFINE_HANDLER(__type) \ + absl::Status Handle##__type(const __type* expr) override { \ + return Handle(expr); \ + } + + XLS_DSLX_EXPR_NODE_EACH(DEFINE_HANDLER) const std::vector& name_refs() { return name_refs_; }