Skip to content

Commit

Permalink
Add element_count<T>() builtin.
Browse files Browse the repository at this point in the history
This is like `array_size(value of type T)` for non-bits-like types, and like `bit_count<T>()` for bits-like types.

Concat in type_system_v2 will use this to support concat of arrays or bit vectors.

PiperOrigin-RevId: 726161251
  • Loading branch information
richmckeever authored and copybara-github committed Feb 12, 2025
1 parent b93ead0 commit a0a134b
Show file tree
Hide file tree
Showing 16 changed files with 232 additions and 4 deletions.
25 changes: 25 additions & 0 deletions docs_src/dslx_std.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,31 @@ fn test_bit_count_size() {
}
```

### `element_count`

`element_count` returns the number of elements in the given type.

* For an array, it is the same as `array_size` for a value of the type.
* For a tuple or struct, it is the number of top-level members.
* For all other types, it is the same as `bit_count`.

```
fn element_count<T: type>() -> u32
```

```dslx
struct MyPoint { x: u32, y: u32 }
#[test]
fn test_bit_count_size() {
assert_eq(element_count<u32[4]>(), u32:4);
assert_eq(element_count<s64>(), u32:64);
assert_eq(element_count<bool>(), u32:1);
assert_eq(element_count<MyPoint>(), u32:2);
assert_eq(element_count<(u32, (u32, u32))>(), u32:2);
}
```

### `widening_cast`, `checked_cast`

`widening_cast` and `checked_cast` cast bits-type values to bits-type values
Expand Down
18 changes: 18 additions & 0 deletions xls/dslx/bytecode/bytecode_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,21 @@ absl::Status BytecodeEmitter::HandleBuiltinBitCount(const Invocation* node) {
return absl::OkStatus();
}

absl::Status BytecodeEmitter::HandleBuiltinElementCount(
const Invocation* node) {
VLOG(5) << "BytecodeEmitter::HandleInvocation - ElementCount @ "
<< node->span().ToString(file_table());

const auto* annotation =
std::get<TypeAnnotation*>(node->explicit_parametrics()[0]);
XLS_ASSIGN_OR_RETURN(Type * type, GetTypeOfNode(annotation, type_info_));
XLS_ASSIGN_OR_RETURN(InterpValue element_count,
GetElementCountAsInterpValue(type));
bytecode_.push_back(
Bytecode(node->span(), Bytecode::Op::kLiteral, element_count));
return absl::OkStatus();
}

absl::Status BytecodeEmitter::HandleChannelDecl(const ChannelDecl* node) {
// Channels are created as constexpr values during type deduction/constexpr
// evaluation, since they're concrete values that need to be shared amongst
Expand Down Expand Up @@ -1099,6 +1114,9 @@ absl::Status BytecodeEmitter::HandleInvocation(const Invocation* node) {
if (name_ref->identifier() == "bit_count") {
return HandleBuiltinBitCount(node);
}
if (name_ref->identifier() == "element_count") {
return HandleBuiltinElementCount(node);
}

if (name_ref->identifier() == "checked_cast") {
return HandleBuiltinCheckedCast(node);
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/bytecode/bytecode_emitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class BytecodeEmitter : public ExprVisitor {
absl::Status HandleBuiltinCheckedCast(const Invocation* node);
absl::Status HandleBuiltinWideningCast(const Invocation* node);
absl::Status HandleBuiltinBitCount(const Invocation* node);
absl::Status HandleBuiltinElementCount(const Invocation* node);
absl::Status HandleBuiltinSend(const Invocation* node);
absl::Status HandleBuiltinSendIf(const Invocation* node);
absl::Status HandleBuiltinRecv(const Invocation* node);
Expand Down
54 changes: 54 additions & 0 deletions xls/dslx/bytecode/bytecode_emitter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1536,6 +1536,60 @@ uadd)";
EXPECT_EQ(kWant, got);
}

TEST(BytecodeEmitterTest, ElementCount) {
constexpr std::string_view kProgram = R"(
struct S {
a: u32,
b: u32
}
struct T<N: u32> {
a: uN[N]
}
#[test]
fn main() -> u32 {
element_count<u32>() +
element_count<s64>() +
element_count<u32[u32:4]>() +
element_count<u32[u32:4][u32:5]>() +
element_count<bool>() +
element_count<S>() +
element_count<T<u32:4>>() +
element_count<(u32, bool)>()
}
)";

constexpr std::string_view kWant = R"(literal u32:32
literal u32:64
uadd
literal u32:4
uadd
literal u32:5
uadd
literal u32:1
uadd
literal u32:2
uadd
literal u32:1
uadd
literal u32:2
uadd)";

ImportData import_data(CreateImportDataForTest());
XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr<BytecodeFunction> bf,
EmitBytecodes(&import_data, kProgram, "main"));

std::string got = absl::StrJoin(
bf->bytecodes(), "\n",
[&import_data](std::string* out, const Bytecode& b) {
absl::StrAppend(
out, b.ToString(import_data.file_table(), /*source_locs=*/false));
});

EXPECT_EQ(kWant, got);
}

TEST(BytecodeEmitterTest, ParameterizedTypeDefToImportedEnum) {
constexpr std::string_view kImported = R"(
pub struct ImportedStruct<X: u32> {
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/bytecode/bytecode_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1542,6 +1542,7 @@ absl::Status BytecodeInterpreter::RunBuiltinFn(const Bytecode& bytecode,
case Builtin::kCheckedCast:
case Builtin::kWideningCast:
case Builtin::kBitCount:
case Builtin::kElementCount:
return absl::UnimplementedError(absl::StrFormat(
"BytecodeInterpreter: builtin function \"%s\" not yet implemented.",
BuiltinToString(builtin)));
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/dslx_builtins.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ namespace xls::dslx {
X("array_rev", kArrayRev) \
X("array_size", kArraySize) \
X("bit_count", kBitCount) \
X("element_count", kElementCount) \
X("assert_eq", kAssertEq) \
X("assert_lt", kAssertLt) \
X("bit_slice_update", kBitSliceUpdate) \
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/frontend/builtins_metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ const absl::flat_hash_map<std::string, BuiltinsData>& GetParametricBuiltins() {
{"array_size", {.signature = "(T[N]) -> u32", .is_ast_node = false}},

{"bit_count", {.signature = "() -> u32", .is_ast_node = false}},
{"element_count", {.signature = "() -> u32", .is_ast_node = false}},

// Bitwise reduction ops.
{"and_reduce", {.signature = "(uN[N]) -> u1", .is_ast_node = false}},
Expand Down
5 changes: 5 additions & 0 deletions xls/dslx/frontend/parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2268,6 +2268,11 @@ TEST_F(ParserTest, BitCount) {
/*populate_dslx_builtins=*/true, "bit_count<u32>()");
}

TEST_F(ParserTest, ElementCount) {
RoundTripExpr("element_count<u32[u32:5]>()", {},
/*populate_dslx_builtins=*/true, "element_count<u32[u32:5]>()");
}

TEST_F(ParserTest, CastOfCastEnum) {
RoundTrip(R"(enum MyEnum : u3 {
SOME_VALUE = 0,
Expand Down
10 changes: 10 additions & 0 deletions xls/dslx/ir_convert/function_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,15 @@ absl::Status FunctionConverter::HandleBuiltinBitCount(const Invocation* node) {
return absl::OkStatus();
}

absl::Status FunctionConverter::HandleBuiltinElementCount(
const Invocation* node) {
// Like bit_count, element_count is always constexpr.
XLS_ASSIGN_OR_RETURN(InterpValue iv, current_type_info_->GetConstExpr(node));
XLS_ASSIGN_OR_RETURN(Value v, InterpValueToValue(iv));
DefConst(node, v);
return absl::OkStatus();
}

absl::Status FunctionConverter::HandleBuiltinWideningCast(
const Invocation* node) {
XLS_RET_CHECK_EQ(node->args().size(), 1);
Expand Down Expand Up @@ -2374,6 +2383,7 @@ absl::Status FunctionConverter::HandleInvocation(const Invocation* node) {
{"widening_cast", &FunctionConverter::HandleBuiltinWideningCast},
{"checked_cast", &FunctionConverter::HandleBuiltinCheckedCast},
{"bit_count", &FunctionConverter::HandleBuiltinBitCount},
{"element_count", &FunctionConverter::HandleBuiltinElementCount},
{"update", &FunctionConverter::HandleBuiltinUpdate},
{"umulp", &FunctionConverter::HandleBuiltinUMulp},
{"smulp", &FunctionConverter::HandleBuiltinSMulp},
Expand Down
1 change: 1 addition & 0 deletions xls/dslx/ir_convert/function_converter.h
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,7 @@ class FunctionConverter {
absl::Status HandleBuiltinUMulp(const Invocation* node);
absl::Status HandleBuiltinWideningCast(const Invocation* node);
absl::Status HandleBuiltinBitCount(const Invocation* node);
absl::Status HandleBuiltinElementCount(const Invocation* node);
absl::Status HandleBuiltinXorReduce(const Invocation* node);

absl::Status HandleBuiltinJoin(const Invocation* node);
Expand Down
29 changes: 29 additions & 0 deletions xls/dslx/ir_convert/ir_converter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3173,6 +3173,35 @@ fn main() -> u32 {
ExpectIr(converted, TestName());
}

TEST(IrConverterTest, ElementCount) {
constexpr std::string_view kProgram = R"(
struct S {
a: u32,
b: u32
}
struct T<N: u32> {
a: uN[N]
}
fn main() -> u32 {
element_count<u32>() +
element_count<s64>() +
element_count<u32[u32:4]>() +
element_count<u32[u32:4][u32:5]>() +
element_count<bool>() +
element_count<S>() +
element_count<T<u32:4>>() +
element_count<(u32, bool)>()
}
)";

XLS_ASSERT_OK_AND_ASSIGN(
std::string converted,
ConvertModuleForTest(kProgram, ConvertOptions{.emit_positions = false}));
ExpectIr(converted, TestName());
}

TEST(IrConverterTest, MapInvocationWithBuiltinFunction) {
constexpr std::string_view program =
R"(
Expand Down
21 changes: 21 additions & 0 deletions xls/dslx/ir_convert/testdata/ir_converter_test_ElementCount.ir
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package test_module

file_number 0 "test_module.x"

fn __test_module__main() -> bits[32] {
literal.1: bits[32] = literal(value=32, id=1)
literal.2: bits[32] = literal(value=64, id=2)
add.3: bits[32] = add(literal.1, literal.2, id=3)
literal.4: bits[32] = literal(value=4, id=4)
add.5: bits[32] = add(add.3, literal.4, id=5)
literal.6: bits[32] = literal(value=5, id=6)
add.7: bits[32] = add(add.5, literal.6, id=7)
literal.8: bits[32] = literal(value=1, id=8)
add.9: bits[32] = add(add.7, literal.8, id=9)
literal.10: bits[32] = literal(value=2, id=10)
add.11: bits[32] = add(add.9, literal.10, id=11)
literal.12: bits[32] = literal(value=1, id=12)
add.13: bits[32] = add(add.11, literal.12, id=13)
literal.14: bits[32] = literal(value=2, id=14)
ret add.15: bits[32] = add(add.13, literal.14, id=15)
}
18 changes: 18 additions & 0 deletions xls/dslx/type_system/deduce_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1059,4 +1059,22 @@ absl::StatusOr<InterpValue> GetBitCountAsInterpValue(const Type* type) {
return InterpValue::MakeU32(bit_count);
}

absl::StatusOr<InterpValue> GetElementCountAsInterpValue(const Type* type) {
if (type->IsMeta()) {
XLS_ASSIGN_OR_RETURN(type, UnwrapMetaType(*type));
}
if (const auto* array_type = dynamic_cast<const ArrayType*>(type)) {
XLS_ASSIGN_OR_RETURN(int64_t size, array_type->size().GetAsInt64());
CHECK(static_cast<uint32_t>(size) == size);
return InterpValue::MakeU32(size);
}
if (const auto* tuple_type = dynamic_cast<const TupleType*>(type)) {
return InterpValue::MakeU32(tuple_type->members().size());
}
if (const auto* struct_type = dynamic_cast<const StructType*>(type)) {
return InterpValue::MakeU32(struct_type->members().size());
}
return GetBitCountAsInterpValue(type);
}

} // namespace xls::dslx
11 changes: 11 additions & 0 deletions xls/dslx/type_system/deduce_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,17 @@ void WarnOnInappropriateConstantName(std::string_view identifier,
// Gets the total bit count of the given `type` as a u32 `InterpValue`.
absl::StatusOr<InterpValue> GetBitCountAsInterpValue(const Type* type);

// Gets the element count of the given `type` as a u32 `InterpValue`. The
// element count depends on the kind of type:
//
// kind element count
// -----------------------------
// bits-like total bit count
// array number of elements indexable by first index
// tuple number of top-level members
// struct number of top-level members
absl::StatusOr<InterpValue> GetElementCountAsInterpValue(const Type* type);

} // namespace xls::dslx

#endif // XLS_DSLX_TYPE_SYSTEM_DEDUCE_UTILS_H_
11 changes: 7 additions & 4 deletions xls/dslx/type_system/typecheck_invocation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -319,15 +319,18 @@ TypecheckParametricBuiltinInvocation(DeduceCtx* ctx,
invocation, InterpValue::MakeU32(static_cast<int32_t>(array_size)));
}

// bit_count is similar to array_size, but uses the parametric argument rather
// than a value.
if (callee_nameref->identifier() == "bit_count") {
// bit_count and element_count are similar to array_size, but uses the
// parametric argument rather than a value.
if (callee_nameref->identifier() == "bit_count" ||
callee_nameref->identifier() == "element_count") {
auto* annotation =
std::get<TypeAnnotation*>(invocation->explicit_parametrics()[0]);
XLS_ASSIGN_OR_RETURN(std::unique_ptr<Type> type,
ctx->DeduceAndResolve(annotation));
XLS_ASSIGN_OR_RETURN(InterpValue value,
GetBitCountAsInterpValue(type.get()));
callee_nameref->identifier() == "element_count"
? GetElementCountAsInterpValue(type.get())
: GetBitCountAsInterpValue(type.get()));
ctx->type_info()->NoteConstExpr(invocation, value);
}

Expand Down
29 changes: 29 additions & 0 deletions xls/dslx/type_system/typecheck_module_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4380,6 +4380,35 @@ fn main() -> u32 {
})"));
}

TEST(TypecheckTest, ElementCount) {
XLS_ASSERT_OK(Typecheck(R"(
struct S {
a: u32,
b: u32
}
struct T<N: u32> {
a: uN[N],
b: u32
}
type A = S;
type B = T;
fn main() -> u32 {
element_count<u32>() +
element_count<s64>() +
element_count<u32[u32:4]>() +
element_count<u32[u32:4][u32:5]>() +
element_count<bool>() +
element_count<S>() +
element_count<T<u32:4>>() +
element_count<(u32, bool)>() +
element_count<A>() +
element_count<B<u32:5>>()
})"));
}

TEST(TypecheckTest, BitCountAsConstExpr) {
XLS_ASSERT_OK(Typecheck(R"(
fn main() -> u64 {
Expand Down

0 comments on commit a0a134b

Please sign in to comment.