Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions xls/dslx/frontend/parser_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2677,6 +2677,27 @@ TEST_F(ParserTest, ModuleWithParametricProcAlias) {
proc Bar = Foo<3, 4>;)");
}

TEST_F(ParserTest, ModuleWithParametricProcAliasCallingParametricFn) {
RoundTrip(R"(fn bar<Y: u32>(i: uN[Y]) -> uN[Y] {
i + i
}
proc Foo<N: u32> {
c: chan<uN[N]> out;
config(output_c: chan<uN[N]> out) {
(output_c,)
}
init {
uN[N]:1
}
next(i: uN[N]) {
let result = bar<N>(i);
let tok = send(join(), c, result);
result + uN[N]:1
}
}
proc Bar = Foo<16>;)");
}

TEST_F(ParserTest, ModuleWithPublicParametricProcAlias) {
RoundTrip(R"(pub proc Foo<A: u32, B: u32> {
config() {}
Expand Down
18 changes: 11 additions & 7 deletions xls/dslx/ir_convert/function_converter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3598,14 +3598,18 @@ absl::Status FunctionConverter::HandleProcNextFunction(
parametric_type->GetTotalBitCount());
XLS_ASSIGN_OR_RETURN(int64_t bit_count, parametric_width_ctd.GetAsInt64());
Value param_value;
if (parametric_value->IsSigned()) {
XLS_ASSIGN_OR_RETURN(int64_t bit_value,
parametric_value->GetBitValueViaSign());
param_value = Value(SBits(bit_value, bit_count));
if (parametric_value->IsBits()) {
if (parametric_value->IsSigned()) {
XLS_ASSIGN_OR_RETURN(int64_t bit_value,
parametric_value->GetBitValueViaSign());
param_value = Value(SBits(bit_value, bit_count));
} else {
XLS_ASSIGN_OR_RETURN(uint64_t bit_value,
parametric_value->GetBitValueViaSign());
param_value = Value(UBits(bit_value, bit_count));
}
} else {
XLS_ASSIGN_OR_RETURN(uint64_t bit_value,
parametric_value->GetBitValueViaSign());
param_value = Value(UBits(bit_value, bit_count));
XLS_ASSIGN_OR_RETURN(param_value, InterpValueToValue(*parametric_value));
}
DefConst(parametric_binding, param_value);
XLS_RETURN_IF_ERROR(
Expand Down
144 changes: 144 additions & 0 deletions xls/dslx/ir_convert/ir_converter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2694,6 +2694,42 @@ pub proc FooAlias = Foo<16>;
ExpectIr(converted);
}

TEST_P(IrConverterWithBothTypecheckVersionsTest,
HandlesParametricProcAliasCallingParametricFn) {
if (GetParam() == TypeInferenceVersion::kVersion1) {
// Proc aliases are not supported in TIv1.
return;
}

constexpr std::string_view program = R"(
fn bar<Y: u32>(i: uN[Y]) -> uN[Y] {
i+i
}

proc Foo<N: u32> {
c: chan<uN[N]> out;
init { uN[N]:1 }
config(output_c: chan<uN[N]> out) {
(output_c,)
}
next(i: uN[N]) {
let result = bar<N>(i);
let tok = send(join(), c, result);
result + uN[N]:1
}
}

pub proc FooAlias = Foo<16>;
)";

auto import_data = CreateImportDataForTest();
XLS_ASSERT_OK_AND_ASSIGN(
std::string converted,
ConvertOneFunctionForTest(program, "FooAlias", import_data,
kNoPosOptions));
ExpectIr(converted);
}

TEST_P(IrConverterWithBothTypecheckVersionsTest,
HandlesProcAliasToImportedProc) {
if (GetParam() == TypeInferenceVersion::kVersion1) {
Expand Down Expand Up @@ -2730,6 +2766,83 @@ pub proc FooAlias = imported::Foo<16>;
ExpectIr(converted);
}

TEST_P(IrConverterWithBothTypecheckVersionsTest,
HandlesProcAliasToNonBitsParametricProc) {
if (GetParam() == TypeInferenceVersion::kVersion1) {
// Proc aliases are not supported in TIv1.
return;
}

constexpr std::string_view program = R"(
struct MyStruct {
a: u32,
b: u32,
}

pub proc Foo<CONFIG: MyStruct> {
c: chan<uN[CONFIG.a]> out;
init { uN[CONFIG.a]:1 }
config(output_c: chan<uN[CONFIG.a]> out) {
(output_c,)
}
next(i: uN[CONFIG.a]) {
let tok = send(join(), c, i);
i + CONFIG.b as uN[CONFIG.a]
}
}

const CONFIG = MyStruct{a: u32:32, b: u32:2};
proc FooAlias = Foo<CONFIG>;
)";

ImportData import_data = CreateImportDataForTest();
XLS_ASSERT_OK_AND_ASSIGN(
std::string converted,
ConvertOneFunctionForTest(program, "FooAlias", import_data,
kNoPosOptions));
ExpectIr(converted);
}

// TODO: google/xls#3353 - enable this test when all the issues have been fixed.
TEST_P(IrConverterWithBothTypecheckVersionsTest,
DISABLED_HandlesSpawnOfNonBitsParametricProc) {
constexpr std::string_view program = R"(
struct MyStruct {
a: u32,
b: u32,
}

proc Foo<CONFIG: MyStruct> {
c: chan<uN[CONFIG.a]> out;
init { uN[CONFIG.a]:1 }
config(output_c: chan<uN[CONFIG.a]> out) {
(output_c,)
}
next(i: uN[CONFIG.a]) {
let tok = send(join(), c, i);
i + CONFIG.b as uN[CONFIG.a]
}
}

const CONFIG = MyStruct{a: u32:32, b: u32:2};
proc Top {
init { () }
config() {
let (p, c) = chan<uN[CONFIG.a]>("my_chan");
spawn Foo<CONFIG>(p);
()
}
next(state: ()) { () }
}
)";

ImportData import_data = CreateImportDataForTest();
XLS_ASSERT_OK_AND_ASSIGN(
std::string converted,
ConvertOneFunctionForTest(program, "Top", import_data, kNoPosOptions));
ExpectIr(converted);
}

TEST_P(IrConverterWithBothTypecheckVersionsTest, HandlesProcWithTypeAlias) {
constexpr std::string_view program = R"(
proc P {
Expand Down Expand Up @@ -6368,6 +6481,37 @@ pub proc FooAlias = Foo<16>;
ExpectIr(converted);
}

TEST_P(ProcScopedChannelsIrConverterTest,
ProcScopedParametricProcAliasCallingParametricFn) {
constexpr std::string_view program = R"(
fn bar<Y: u32>(i: uN[Y]) -> uN[Y] {
i+i
}

proc Foo<N: u32> {
c: chan<uN[N]> out;
init { uN[N]:1 }
config(output_c: chan<uN[N]> out) {
(output_c,)
}
next(i: uN[N]) {
let result = bar<N>(i);
let tok = send(join(), c, result);
result + uN[N]:1
}
}

pub proc FooAlias = Foo<16>;
)";

auto import_data = CreateImportDataForTest();
XLS_ASSERT_OK_AND_ASSIGN(
std::string converted,
ConvertOneFunctionForTest(program, "FooAlias", import_data,
kProcScopedChannelOptions));
ExpectIr(converted);
}

TEST_P(ProcScopedChannelsIrConverterTest, ProcScopedProcAliasToImportedProc) {
ImportData import_data = CreateImportDataForTest();

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package test_module

file_number 0 "test_module.x"

chan test_module__output_c(bits[16], id=0, kind=streaming, ops=send_only, flow_control=ready_valid, strictness=proven_mutually_exclusive)

fn __test_module__bar__16(i: bits[16] id=1) -> bits[16] {
Y: bits[5] = literal(value=16, id=2)
ret add.3: bits[16] = add(i, i, id=3)
}

top proc __test_module__FooAlias_next(__state: bits[16], init={1}) {
__state: bits[16] = state_read(state_element=__state, id=5)
result: bits[16] = invoke(__state, to_apply=__test_module__bar__16, id=8)
literal.11: bits[16] = literal(value=1, id=11)
after_all.9: token = after_all(id=9)
literal.6: bits[1] = literal(value=1, id=6)
add.12: bits[16] = add(result, literal.11, id=12)
__token: token = literal(value=token, id=4)
N: bits[32] = literal(value=16, id=7)
tok: token = send(after_all.9, result, predicate=literal.6, channel=test_module__output_c, id=10)
next_value.13: () = next_value(param=__state, value=add.12, id=13)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package test_module

file_number 0 "test_module.x"

chan test_module__output_c(bits[32], id=0, kind=streaming, ops=send_only, flow_control=ready_valid, strictness=proven_mutually_exclusive)

top proc __test_module__FooAlias_next(__state: bits[32], init={1}) {
CONFIG: (bits[32], bits[32]) = literal(value=(32, 2), id=4)
CONFIG_b: bits[32] = tuple_index(CONFIG, index=1, id=7)
__state: bits[32] = state_read(state_element=__state, id=2)
zero_ext.8: bits[32] = zero_ext(CONFIG_b, new_bit_count=32, id=8)
after_all.5: token = after_all(id=5)
literal.3: bits[1] = literal(value=1, id=3)
add.9: bits[32] = add(__state, zero_ext.8, id=9)
__token: token = literal(value=token, id=1)
tok: token = send(after_all.5, __state, predicate=literal.3, channel=test_module__output_c, id=6)
next_value.10: () = next_value(param=__state, value=add.9, id=10)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package test_module

file_number 0 "test_module.x"

fn __test_module__bar__16(i: bits[16] id=1) -> bits[16] {
Y: bits[5] = literal(value=16, id=2)
ret add.3: bits[16] = add(i, i, id=3)
}

top proc __test_module__FooAlias_next<output_c: bits[16] out>(__state: bits[16], init={1}) {
chan_interface output_c(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
__state: bits[16] = state_read(state_element=__state, id=5)
result: bits[16] = invoke(__state, to_apply=__test_module__bar__16, id=9)
literal.12: bits[16] = literal(value=1, id=12)
after_all.10: token = after_all(id=10)
literal.6: bits[1] = literal(value=1, id=6)
add.13: bits[16] = add(result, literal.12, id=13)
__token: token = literal(value=token, id=4)
N: bits[32] = literal(value=16, id=7)
tuple.8: () = tuple(id=8)
tok: token = send(after_all.10, result, predicate=literal.6, channel=output_c, id=11)
next_value.14: () = next_value(param=__state, value=add.13, id=14)
}
1 change: 1 addition & 0 deletions xls/dslx/type_system_v2/inference_table.cc
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ class InferenceTableImpl : public InferenceTable {
ParametricContext* result = context.get();
parametric_contexts_.push_back(std::move(context));
mutable_parametric_context_data_.emplace(result, std::move(mutable_data));
SetParametricEnv(result, env);
return result;
}

Expand Down