Skip to content

Commit

Permalink
Merge pull request #1774 from xlsynth:cdleary/2024-12-06-fuzz-xn
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 704973407
  • Loading branch information
copybara-github committed Dec 11, 2024
2 parents 98f2b8d + b7a6e5f commit 1dfce58
Show file tree
Hide file tree
Showing 30 changed files with 736 additions and 253 deletions.
91 changes: 0 additions & 91 deletions xls/dslx/bytecode/bytecode_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -470,49 +470,6 @@ static absl::StatusOr<std::unique_ptr<BitsType>> GetTypeOfNodeAsBits(
return std::make_unique<BitsType>(is_signed, bits_like->size);
}

static absl::Status MaybeCheckArrayToBitsCast(const AstNode* node,
const Type* from,
const Type* to) {
const ArrayType* from_array = dynamic_cast<const ArrayType*>(from);
bool to_is_bits_like = IsBitsLike(*to);

if (from_array != nullptr && !to_is_bits_like) {
return absl::InternalError(absl::StrCat(
"The only valid array cast is to bits: ", node->ToString()));
}

if (from_array == nullptr || !to_is_bits_like) {
return absl::OkStatus();
}

// Bits-constructor acts as a bits type, so we don't need to perform
// array-oriented cast checks.
if (IsArrayOfBitsConstructor(*from_array)) {
return absl::OkStatus();
}

// Check casting from an array to bits.
if (from_array->element_type().GetAllDims().size() != 1) {
return absl::InternalError(
"Only casts to/from one-dimensional arrays are supported.");
}

XLS_ASSIGN_OR_RETURN(TypeDim bit_count_dim, from_array->GetTotalBitCount());
XLS_ASSIGN_OR_RETURN(int64_t array_bit_count, bit_count_dim.GetAsInt64());

XLS_ASSIGN_OR_RETURN(bit_count_dim, to->GetTotalBitCount());
XLS_ASSIGN_OR_RETURN(int64_t bits_bit_count, bit_count_dim.GetAsInt64());

if (array_bit_count != bits_bit_count) {
return absl::InternalError(absl::StrFormat(
"Array-to-bits cast bit counts must match. "
"Saw %d for \"from\" type `%s` vs %d for \"to\" type `%s`.",
array_bit_count, from->ToString(), bits_bit_count, to->ToString()));
}

return absl::OkStatus();
}

static absl::Status MaybeCheckEnumToBitsCast(const AstNode* node,
const Type* from, const Type* to) {
const EnumType* from_enum = dynamic_cast<const EnumType*>(from);
Expand All @@ -526,52 +483,6 @@ static absl::Status MaybeCheckEnumToBitsCast(const AstNode* node,
return absl::OkStatus();
}

static absl::Status MaybeCheckBitsToArrayCast(const AstNode* node,
const Type* from,
const Type* to) {
bool from_is_bits_like = IsBitsLike(*from);
const ArrayType* to_array = dynamic_cast<const ArrayType*>(to);

if (to_array != nullptr && !from_is_bits_like) {
return absl::InternalError(absl::StrCat(
"The only valid array cast is from bits: ", node->ToString()));
}

if (!from_is_bits_like || to_array == nullptr) {
return absl::OkStatus();
}

// Bits-constructor acts as a bits type, so we don't need to perform
// array-oriented cast checks.
if (IsArrayOfBitsConstructor(*to_array)) {
return absl::OkStatus();
}

// Casting from bits to an array.
if (to_array->element_type().GetAllDims().size() != 1) {
return absl::InternalError(
"Only casts to/from one-dimensional arrays are supported.");
}

VLOG(5) << "from_bits: " << from->ToString()
<< " to_array: " << to_array->ToString();

XLS_ASSIGN_OR_RETURN(TypeDim bit_count_dim, from->GetTotalBitCount());
XLS_ASSIGN_OR_RETURN(int64_t bits_bit_count, bit_count_dim.GetAsInt64());

XLS_ASSIGN_OR_RETURN(bit_count_dim, to_array->GetTotalBitCount());
XLS_ASSIGN_OR_RETURN(int64_t array_bit_count, bit_count_dim.GetAsInt64());

if (array_bit_count != bits_bit_count) {
return absl::InternalError(absl::StrFormat(
"Bits-to-array cast bit counts must match. "
"bits-type `%s` bit count: %d; array-type bit count for `%s`: %d.",
from->ToString(), bits_bit_count, to->ToString(), array_bit_count));
}

return absl::OkStatus();
}

static absl::Status MaybeCheckBitsToEnumCast(const AstNode* node,
const Type* from, const Type* to) {
bool from_is_bits_like = IsBitsLike(*from);
Expand Down Expand Up @@ -613,9 +524,7 @@ absl::Status BytecodeEmitter::HandleCast(const Cast* node) {

XLS_RETURN_IF_ERROR(CheckSupportedCastTypes(node, from));
XLS_RETURN_IF_ERROR(CheckSupportedCastTypes(node, to));
XLS_RETURN_IF_ERROR(MaybeCheckArrayToBitsCast(node, from, to));
XLS_RETURN_IF_ERROR(MaybeCheckEnumToBitsCast(node, from, to));
XLS_RETURN_IF_ERROR(MaybeCheckBitsToArrayCast(node, from, to));
XLS_RETURN_IF_ERROR(MaybeCheckBitsToEnumCast(node, from, to));

bytecode_.push_back(
Expand Down
2 changes: 2 additions & 0 deletions xls/dslx/bytecode/bytecode_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,8 @@ absl::Status BytecodeInterpreter::EvalTrace(const Bytecode& bytecode) {
absl::Status BytecodeInterpreter::EvalWidthSlice(const Bytecode& bytecode) {
XLS_ASSIGN_OR_RETURN(const Type* type, bytecode.type_data());
XLS_ASSIGN_OR_RETURN(const Type* unwrapped_type, UnwrapMetaType(*type));

// Width slice only works on bits-like types.
std::optional<BitsLikeProperties> bits_like = GetBitsLike(*unwrapped_type);
XLS_RET_CHECK(bits_like.has_value())
<< "Wide slice type is not bits-like: " << type->ToString();
Expand Down
40 changes: 9 additions & 31 deletions xls/dslx/constexpr_evaluator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,30 +62,6 @@
#include "xls/ir/bits.h"

namespace xls::dslx {
namespace {

// Fully instantiate the given parametric BitsType using the symbol mappings in
// `env`.
absl::StatusOr<std::unique_ptr<BitsType>> InstantiateParametricNumberType(
const absl::flat_hash_map<std::string, InterpValue>& env,
const BitsType* bits_type) {
ParametricExpression::Env parametric_env;
for (const auto& [k, v] : env) {
parametric_env[k] = v;
}
ParametricExpression::Evaluated e =
bits_type->size().parametric().Evaluate(parametric_env);
if (!std::holds_alternative<InterpValue>(e)) {
return absl::InternalError(
absl::StrCat("Parametric number size did not evaluate to a constant: ",
bits_type->size().ToString()));
}
return std::make_unique<BitsType>(
bits_type->is_signed(),
std::get<InterpValue>(e).GetBitValueViaSign().value());
}

} // namespace

/* static */ absl::Status ConstexprEvaluator::Evaluate(
ImportData* import_data, TypeInfo* type_info,
Expand Down Expand Up @@ -553,13 +529,15 @@ absl::Status ConstexprEvaluator::HandleNumber(const Number* expr) {
XLS_RET_CHECK(tt != nullptr);
type_ptr = tt->wrapped().get();

const BitsType* bt = down_cast<const BitsType*>(type_ptr);
XLS_RET_CHECK(bt != nullptr);
if (bt->size().IsParametric()) {
XLS_ASSIGN_OR_RETURN(temp_type, InstantiateParametricNumberType(
constexpr_env_data.env, bt));
type_ptr = temp_type.get();
}
std::optional<BitsLikeProperties> bits_like = GetBitsLike(*type_ptr);
XLS_RET_CHECK(bits_like.has_value())
<< "Type for number should be bits-like; got: " << type_ptr->ToString();

// Materialize the bits type.
XLS_ASSIGN_OR_RETURN(bool is_signed, bits_like->is_signed.GetAsBool());
XLS_ASSIGN_OR_RETURN(int64_t bit_count, bits_like->size.GetAsInt64());
temp_type = std::make_unique<BitsType>(is_signed, bit_count);
type_ptr = temp_type.get();
} else if (type_ != nullptr) {
type_ptr = type_;
} else if (expr->number_kind() == NumberKind::kBool) {
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/frontend/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ cc_library(
"//xls/common/status:ret_check",
"//xls/common/status:status_macros",
"//xls/dslx:interp_value",
"//xls/ir:number_parser",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/log",
Expand Down
7 changes: 7 additions & 0 deletions xls/dslx/frontend/ast.h
Original file line number Diff line number Diff line change
Expand Up @@ -1005,8 +1005,15 @@ class NameRef : public Expr {
};

enum class NumberKind : uint8_t {
// This kind is used when a keyword `true` or `false` is used as a number
// literal.
kBool,

// This kind is used when a character literal is used as a number literal like
// 'a'.
kCharacter,

// This kind is used for all other number literals.
kOther,
};

Expand Down
54 changes: 54 additions & 0 deletions xls/dslx/frontend/ast_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "xls/dslx/frontend/pos.h"
#include "xls/dslx/frontend/token_utils.h"
#include "xls/dslx/interp_value.h"
#include "xls/ir/number_parser.h"

namespace xls::dslx {
namespace {
Expand Down Expand Up @@ -280,6 +281,32 @@ absl::StatusOr<InterpValue> GetArrayTypeColonAttr(
[&] { return array_type->ToString(); });
}

// Attempts to evaluate an expression as a literal boolean.
//
// This has a few simple forms:
// - `true` / `false`
// - `bool:0x0` / `bool:0x1`
static std::optional<bool> TryEvaluateAsBool(const Expr* expr) {
const Number* number = dynamic_cast<const Number*>(expr);
if (number == nullptr) {
return std::nullopt;
}
if (number->number_kind() == NumberKind::kBool) {
CHECK(number->text() == "true" || number->text() == "false");
return number->text() == "true";
}
if (number->number_kind() == NumberKind::kOther) {
absl::StatusOr<std::pair<bool, Bits>> sm =
GetSignAndMagnitude(number->text());
if (sm.ok() && sm->second.bit_count() <= 1) {
// Note: the zero value is given as a bit count of zero.
bool value = sm->second.bit_count() == 0 ? 0 : sm->second.Get(0);
return std::optional<bool>(value);
}
}
return std::nullopt;
}

std::optional<BitVectorMetadata> ExtractBitVectorMetadata(
const TypeAnnotation* type_annotation) {
bool is_enum = false;
Expand Down Expand Up @@ -332,9 +359,36 @@ std::optional<BitVectorMetadata> ExtractBitVectorMetadata(
return BitVectorMetadata{
.bit_count = bit_count, .is_signed = is_signed.value(), .kind = kind};
}

if (const ArrayTypeAnnotation* array_type =
dynamic_cast<const ArrayTypeAnnotation*>(type);
array_type != nullptr) {
// xN[..] has yet another level of array that annotates the signedness.
if (const ArrayTypeAnnotation* inner_array_type =
dynamic_cast<const ArrayTypeAnnotation*>(
array_type->element_type());
inner_array_type != nullptr) {
const BuiltinTypeAnnotation* maybe_xn =
dynamic_cast<const BuiltinTypeAnnotation*>(
inner_array_type->element_type());
if (maybe_xn != nullptr && maybe_xn->builtin_type() == BuiltinType::kXN) {
// We can only extract the signedness from the inner array dimension if
// it's a bare number.
std::optional<bool> is_signed =
TryEvaluateAsBool(inner_array_type->dim());

// If we can't determine signedness we bail, as it's required metadata
// for us to return.
if (!is_signed.has_value()) {
return std::nullopt;
}

return BitVectorMetadata{.bit_count = array_type->dim(),
.is_signed = is_signed.value(),
.kind = kind};
}
}

// bits[..], uN[..], and sN[..] are bit-vector types but a represented with
// ArrayTypeAnnotations.
const BuiltinTypeAnnotation* builtin_element_type =
Expand Down
8 changes: 5 additions & 3 deletions xls/dslx/frontend/ast_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,11 @@ struct BitVectorMetadata {
};

// Returns metadata about the bit-vector type if `type_annotation` refers to a
// type whose underlying representation is a bit-vector. Examples include u32,
// s10, uN[42], bits[11], enums, etc, and aliases of these types. Returns
// std::nullopt otherwise.
// type whose underlying representation is a bit-vector. Returns std::nullopt
// otherwise.
//
// Examples include `u32`, `s10`, `uN[42]`, `bits[11]`, enums, etc, and
// aliases of these types.
std::optional<BitVectorMetadata> ExtractBitVectorMetadata(
const TypeAnnotation* type_annotation);

Expand Down
23 changes: 22 additions & 1 deletion xls/dslx/frontend/ast_utils_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ struct TheStruct {
l: MyStruct,
m: MyTuple,
n: MyArray,
x: xN[bool:0x0][44],
y: xN[bool:0x1][44],
}
)";
FileTable file_table;
Expand All @@ -76,7 +78,10 @@ struct TheStruct {
}
}
ASSERT_NE(the_struct_def, nullptr);
auto get_type_metadata = [&](std::string_view name) {

// Helper that extracts the metadata associated with a given field name.
auto get_type_metadata =
[&](std::string_view name) -> std::optional<BitVectorMetadata> {
for (const StructMember& member : the_struct_def->members()) {
if (member.name == name) {
return ExtractBitVectorMetadata(member.type);
Expand Down Expand Up @@ -124,6 +129,22 @@ struct TheStruct {
EXPECT_FALSE(get_type_metadata("i")->is_signed);
EXPECT_EQ(get_type_metadata("i")->kind, BitVectorKind::kEnumTypeAlias);

// Note: for an `xN` we expect the bit count to be an expression.
ASSERT_TRUE(get_type_metadata("x").has_value());
ASSERT_TRUE(std::holds_alternative<Expr*>(get_type_metadata("x")->bit_count));
EXPECT_EQ(std::get<Expr*>(get_type_metadata("x")->bit_count)->ToString(),
"44");
EXPECT_FALSE(get_type_metadata("x")->is_signed);
EXPECT_EQ(get_type_metadata("x")->kind, BitVectorKind::kBitType);

// Note: for an `xN` we expect the bit count to be an expression.
ASSERT_TRUE(get_type_metadata("y").has_value());
ASSERT_TRUE(std::holds_alternative<Expr*>(get_type_metadata("y")->bit_count));
EXPECT_EQ(std::get<Expr*>(get_type_metadata("y")->bit_count)->ToString(),
"44");
EXPECT_TRUE(get_type_metadata("y")->is_signed);
EXPECT_EQ(get_type_metadata("y")->kind, BitVectorKind::kBitType);

EXPECT_FALSE(get_type_metadata("j").has_value());
EXPECT_FALSE(get_type_metadata("k").has_value());
EXPECT_FALSE(get_type_metadata("l").has_value());
Expand Down
37 changes: 37 additions & 0 deletions xls/dslx/frontend/parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2114,6 +2114,43 @@ TEST_F(ParserTest, ArrayTypeAnnotation) {
EXPECT_EQ(array_type->element_type()->ToString(), "u8");
}

// Tests parsing of a type annotation made from the `xN` bits constructor.
TEST_F(ParserTest, xNTypeAnnotation) {
std::string s = "xN[false][128]";
scanner_.emplace(file_table_, Fileno(0), s);
parser_.emplace("test", &*scanner_);
Bindings bindings;
XLS_ASSERT_OK_AND_ASSIGN(TypeAnnotation * ta,
ParseTypeAnnotation(parser_.value(), bindings));

// Outer array type gives the bit count.
auto* outer_array_type = dynamic_cast<ArrayTypeAnnotation*>(ta);
EXPECT_NE(outer_array_type, nullptr);
auto* element_type = outer_array_type->element_type();
EXPECT_NE(element_type, nullptr);
auto* outer_dim = outer_array_type->dim();
EXPECT_NE(outer_dim, nullptr);
auto* outer_number = dynamic_cast<Number*>(outer_dim);
EXPECT_NE(outer_number, nullptr);
EXPECT_EQ(outer_number->ToString(), "128");

// Inner array type gives the signedness.
auto* inner_array_type = dynamic_cast<ArrayTypeAnnotation*>(element_type);
EXPECT_NE(inner_array_type, nullptr);
auto* inner_dim = inner_array_type->dim();
EXPECT_NE(inner_dim, nullptr);
auto* inner_number = dynamic_cast<Number*>(inner_dim);
EXPECT_NE(inner_number, nullptr);
EXPECT_EQ(inner_number->ToString(), "false");

// The innermost element is an `xN` builtin type annotation.
auto* builtin_type =
dynamic_cast<BuiltinTypeAnnotation*>(inner_array_type->element_type());
EXPECT_NE(builtin_type, nullptr);
EXPECT_EQ(builtin_type->builtin_type(), BuiltinType::kXN);
EXPECT_EQ(builtin_type->GetBitCount(), 0);
}

TEST_F(ParserTest, TupleArrayAndInt) {
Expr* e = RoundTripExpr("(u8[4]:[1, 2, 3, 4], 7)", {}, false, std::nullopt);
auto* tuple = dynamic_cast<XlsTuple*>(e);
Expand Down
Loading

0 comments on commit 1dfce58

Please sign in to comment.