Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 110 additions & 15 deletions xls/public/c_api_dslx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,23 @@ struct InvocationCalleeDataArray {
std::vector<xls::dslx::InvocationCalleeData> entries;
};

template <typename T>
xls::dslx::ModuleMember* FindModuleMemberForNode(T* node) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like this would be better as bool IsModuleMember(const AstNode* node) The current name kind of implies that it's going to dig up the ModuleMember that the node is in.

if (node == nullptr) {
return nullptr;
}
xls::dslx::Module* module = node->owner();
if (module == nullptr) {
return nullptr;
}
for (xls::dslx::ModuleMember& member : module->top()) {
if (std::holds_alternative<T*>(member) && std::get<T*>(member) == node) {
return &member;
}
}
return nullptr;
}

} // namespace

extern "C" {
Expand Down Expand Up @@ -382,13 +399,47 @@ bool xls_dslx_typechecked_module_clone_removing_functions(
char** error_out, struct xls_dslx_typechecked_module** result_out) {
CHECK(error_out != nullptr);
CHECK(result_out != nullptr);
auto fail = [&](absl::string_view message) {
*error_out = xls::ToOwnedCString(std::string(message));
*result_out = nullptr;
return false;
};
if (function_count != 0 && functions == nullptr) {
return fail("functions array is null");
}
std::vector<xls_dslx_module_member*> members;
members.reserve(function_count);
for (size_t i = 0; i < function_count; ++i) {
auto* fn = reinterpret_cast<xls::dslx::Function*>(functions[i]);
xls::dslx::ModuleMember* member = FindModuleMemberForNode(fn);
if (member == nullptr) {
return fail("function does not belong to the provided module");
}
members.push_back(reinterpret_cast<xls_dslx_module_member*>(member));
}
return xls_dslx_typechecked_module_clone_removing_members(
tm, members.empty() ? nullptr : members.data(), function_count,
install_subject, import_data, error_out, result_out);
}

bool xls_dslx_typechecked_module_clone_removing_members(
struct xls_dslx_typechecked_module* tm,
struct xls_dslx_module_member* members[], size_t member_count,
const char* install_subject, struct xls_dslx_import_data* import_data,
char** error_out, struct xls_dslx_typechecked_module** result_out) {
CHECK(error_out != nullptr);
CHECK(result_out != nullptr);
*error_out = nullptr;
*result_out = nullptr;

if (tm == nullptr || import_data == nullptr) {
*error_out = xls::ToOwnedCString("null argument provided");
return false;
}
if (member_count != 0 && members == nullptr) {
*error_out = xls::ToOwnedCString("members array is null");
return false;
}

auto* cpp_tm = reinterpret_cast<xls::dslx::TypecheckedModule*>(tm);
auto* cpp_import_data = reinterpret_cast<xls::dslx::ImportData*>(import_data);
Expand All @@ -400,19 +451,20 @@ bool xls_dslx_typechecked_module_clone_removing_functions(
}

std::vector<const xls::dslx::AstNode*> nodes_to_remove;
nodes_to_remove.reserve(function_count);
for (size_t i = 0; i < function_count; ++i) {
if (functions == nullptr || functions[i] == nullptr) {
*error_out = xls::ToOwnedCString("functions array contains null entry");
nodes_to_remove.reserve(member_count);
for (size_t i = 0; i < member_count; ++i) {
if (members[i] == nullptr) {
*error_out = xls::ToOwnedCString("members array contains null entry");
return false;
}
auto* fn = reinterpret_cast<xls::dslx::Function*>(functions[i]);
if (fn->owner() != cpp_tm->module) {
auto* cpp_member = reinterpret_cast<xls::dslx::ModuleMember*>(members[i]);
const xls::dslx::AstNode* node = xls::dslx::ToAstNode(*cpp_member);
if (node == nullptr || node->owner() != cpp_tm->module) {
*error_out = xls::ToOwnedCString(
"function does not belong to the provided module");
"module member does not belong to the provided module");
return false;
}
nodes_to_remove.push_back(fn);
nodes_to_remove.push_back(node);
}

absl::StatusOr<std::unique_ptr<xls::dslx::Module>> cloned_module_or =
Expand Down Expand Up @@ -556,6 +608,56 @@ struct xls_dslx_function* xls_dslx_module_member_get_function(
return nullptr;
}

struct xls_dslx_quickcheck* xls_dslx_module_member_get_quickcheck(
struct xls_dslx_module_member* member) {
auto* cpp_member = reinterpret_cast<xls::dslx::ModuleMember*>(member);
auto* cpp_qc = std::get<xls::dslx::QuickCheck*>(*cpp_member);
return reinterpret_cast<xls_dslx_quickcheck*>(cpp_qc);
}

struct xls_dslx_module_member* xls_dslx_module_member_from_constant_def(
struct xls_dslx_constant_def* constant_def) {
auto* cpp_constant_def =
reinterpret_cast<xls::dslx::ConstantDef*>(constant_def);
xls::dslx::ModuleMember* member = FindModuleMemberForNode(cpp_constant_def);
return reinterpret_cast<xls_dslx_module_member*>(member);
}

struct xls_dslx_module_member* xls_dslx_module_member_from_struct_def(
struct xls_dslx_struct_def* struct_def) {
auto* cpp_struct_def = reinterpret_cast<xls::dslx::StructDef*>(struct_def);
xls::dslx::ModuleMember* member = FindModuleMemberForNode(cpp_struct_def);
return reinterpret_cast<xls_dslx_module_member*>(member);
}

struct xls_dslx_module_member* xls_dslx_module_member_from_enum_def(
struct xls_dslx_enum_def* enum_def) {
auto* cpp_enum_def = reinterpret_cast<xls::dslx::EnumDef*>(enum_def);
xls::dslx::ModuleMember* member = FindModuleMemberForNode(cpp_enum_def);
return reinterpret_cast<xls_dslx_module_member*>(member);
}

struct xls_dslx_module_member* xls_dslx_module_member_from_type_alias(
struct xls_dslx_type_alias* type_alias) {
auto* cpp_type_alias = reinterpret_cast<xls::dslx::TypeAlias*>(type_alias);
xls::dslx::ModuleMember* member = FindModuleMemberForNode(cpp_type_alias);
return reinterpret_cast<xls_dslx_module_member*>(member);
}

struct xls_dslx_module_member* xls_dslx_module_member_from_function(
struct xls_dslx_function* function) {
auto* cpp_function = reinterpret_cast<xls::dslx::Function*>(function);
xls::dslx::ModuleMember* member = FindModuleMemberForNode(cpp_function);
return reinterpret_cast<xls_dslx_module_member*>(member);
}

struct xls_dslx_module_member* xls_dslx_module_member_from_quickcheck(
struct xls_dslx_quickcheck* quickcheck) {
auto* cpp_qc = reinterpret_cast<xls::dslx::QuickCheck*>(quickcheck);
xls::dslx::ModuleMember* member = FindModuleMemberForNode(cpp_qc);
return reinterpret_cast<xls_dslx_module_member*>(member);
}

bool xls_dslx_function_is_parametric(struct xls_dslx_function* fn) {
auto* cpp_function = reinterpret_cast<xls::dslx::Function*>(fn);
return cpp_function->IsParametric();
Expand Down Expand Up @@ -689,13 +791,6 @@ struct xls_dslx_function* xls_dslx_call_graph_get_callee_function(
const_cast<xls::dslx::Function*>(fn));
}

struct xls_dslx_quickcheck* xls_dslx_module_member_get_quickcheck(
struct xls_dslx_module_member* member) {
auto* cpp_member = reinterpret_cast<xls::dslx::ModuleMember*>(member);
auto* cpp_qc = std::get<xls::dslx::QuickCheck*>(*cpp_member);
return reinterpret_cast<xls_dslx_quickcheck*>(cpp_qc);
}

struct xls_dslx_function* xls_dslx_quickcheck_get_function(
struct xls_dslx_quickcheck* quickcheck) {
auto* cpp_qc = reinterpret_cast<xls::dslx::QuickCheck*>(quickcheck);
Expand Down
36 changes: 30 additions & 6 deletions xls/public/c_api_dslx.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,12 @@ bool xls_dslx_typechecked_module_clone_removing_functions(
const char* install_subject, struct xls_dslx_import_data* import_data,
char** error_out, struct xls_dslx_typechecked_module** result_out);

bool xls_dslx_typechecked_module_clone_removing_members(
struct xls_dslx_typechecked_module* tm,
struct xls_dslx_module_member* members[], size_t member_count,
const char* install_subject, struct xls_dslx_import_data* import_data,
char** error_out, struct xls_dslx_typechecked_module** result_out);

void xls_dslx_typechecked_module_free(struct xls_dslx_typechecked_module* tm);

struct xls_dslx_module* xls_dslx_typechecked_module_get_module(
Expand Down Expand Up @@ -234,6 +240,30 @@ struct xls_dslx_type_alias* xls_dslx_module_member_get_type_alias(
struct xls_dslx_function* xls_dslx_module_member_get_function(
struct xls_dslx_module_member*);

// Returns the QuickCheck AST node from the given module member. The caller
// should ensure the module member kind is
// `xls_dslx_module_member_kind_quick_check`.
struct xls_dslx_quickcheck* xls_dslx_module_member_get_quickcheck(
struct xls_dslx_module_member*);

struct xls_dslx_module_member *xls_dslx_module_member_from_constant_def(
struct xls_dslx_constant_def* constant_def);

struct xls_dslx_module_member *xls_dslx_module_member_from_struct_def(
struct xls_dslx_struct_def* struct_def);

struct xls_dslx_module_member *xls_dslx_module_member_from_enum_def(
struct xls_dslx_enum_def* enum_def);

struct xls_dslx_module_member *xls_dslx_module_member_from_type_alias(
struct xls_dslx_type_alias* type_alias);

struct xls_dslx_module_member *xls_dslx_module_member_from_function(
struct xls_dslx_function* function);

struct xls_dslx_module_member *xls_dslx_module_member_from_quickcheck(
struct xls_dslx_quickcheck* quickcheck);

// Returns whether the given DSLX function is parametric.
bool xls_dslx_function_is_parametric(struct xls_dslx_function*);

Expand Down Expand Up @@ -281,12 +311,6 @@ struct xls_dslx_function* xls_dslx_call_graph_get_callee_function(
struct xls_dslx_call_graph* call_graph, struct xls_dslx_function* caller,
int64_t callee_index);

// Returns the QuickCheck AST node from the given module member. The caller
// should ensure the module member kind is
// `xls_dslx_module_member_kind_quick_check`.
struct xls_dslx_quickcheck* xls_dslx_module_member_get_quickcheck(
struct xls_dslx_module_member*);

// Retrieves the underlying function associated with the given QuickCheck.
struct xls_dslx_function* xls_dslx_quickcheck_get_function(
struct xls_dslx_quickcheck*);
Expand Down
7 changes: 7 additions & 0 deletions xls/public/c_api_symbols.txt
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,12 @@ xls_dslx_module_get_type_definition_as_struct_def
xls_dslx_module_get_type_definition_as_type_alias
xls_dslx_module_get_type_definition_count
xls_dslx_module_get_type_definition_kind
xls_dslx_module_member_from_constant_def
xls_dslx_module_member_from_enum_def
xls_dslx_module_member_from_function
xls_dslx_module_member_from_quickcheck
xls_dslx_module_member_from_struct_def
xls_dslx_module_member_from_type_alias
xls_dslx_module_member_get_constant_def
xls_dslx_module_member_get_enum_def
xls_dslx_module_member_get_function
Expand Down Expand Up @@ -231,6 +237,7 @@ xls_dslx_type_ref_get_type_definition
xls_dslx_type_ref_type_annotation_get_type_ref
xls_dslx_type_to_string
xls_dslx_typechecked_module_clone_removing_functions
xls_dslx_typechecked_module_clone_removing_members
xls_dslx_typechecked_module_free
xls_dslx_typechecked_module_get_module
xls_dslx_typechecked_module_get_type_info
Expand Down
63 changes: 55 additions & 8 deletions xls/public/c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1625,6 +1625,9 @@ TEST(XlsCApiTest, DslxModuleMembers) {
xls_dslx_module_get_member(module, 0);
xls_dslx_struct_def* struct_def =
xls_dslx_module_member_get_struct_def(struct_def_member);
xls_dslx_module_member* struct_def_via_from =
xls_dslx_module_member_from_struct_def(struct_def);
EXPECT_EQ(struct_def_via_from, struct_def_member);
char* struct_def_identifier =
xls_dslx_struct_def_get_identifier(struct_def);
absl::Cleanup free_struct_def_identifier(
Expand All @@ -1638,6 +1641,9 @@ TEST(XlsCApiTest, DslxModuleMembers) {
xls_dslx_module_get_member(module, 1);
xls_dslx_enum_def* enum_def =
xls_dslx_module_member_get_enum_def(enum_def_member);
xls_dslx_module_member* enum_def_via_from =
xls_dslx_module_member_from_enum_def(enum_def);
EXPECT_EQ(enum_def_via_from, enum_def_member);
char* enum_def_identifier = xls_dslx_enum_def_get_identifier(enum_def);
absl::Cleanup free_enum_def_identifier(
[&] { xls_c_str_free(enum_def_identifier); });
Expand All @@ -1650,6 +1656,9 @@ TEST(XlsCApiTest, DslxModuleMembers) {
xls_dslx_module_get_member(module, 2);
xls_dslx_type_alias* type_alias =
xls_dslx_module_member_get_type_alias(type_alias_member);
xls_dslx_module_member* type_alias_via_from =
xls_dslx_module_member_from_type_alias(type_alias);
EXPECT_EQ(type_alias_via_from, type_alias_member);
char* type_alias_identifier =
xls_dslx_type_alias_get_identifier(type_alias);
absl::Cleanup free_type_alias_identifier(
Expand All @@ -1663,6 +1672,9 @@ TEST(XlsCApiTest, DslxModuleMembers) {
xls_dslx_module_get_member(module, 3);
xls_dslx_constant_def* constant_def =
xls_dslx_module_member_get_constant_def(constant_def_member);
xls_dslx_module_member* constant_def_via_from =
xls_dslx_module_member_from_constant_def(constant_def);
EXPECT_EQ(constant_def_via_from, constant_def_member);
char* constant_def_name = xls_dslx_constant_def_get_name(constant_def);
absl::Cleanup free_constant_def_name(
[&] { xls_c_str_free(constant_def_name); });
Expand Down Expand Up @@ -1715,7 +1727,8 @@ fn main(x: u32) -> u32 {
absl::Cleanup free_tm([=] { xls_dslx_typechecked_module_free(tm); });

xls_dslx_module* module = xls_dslx_typechecked_module_get_module(tm);
auto find_function = [&](std::string_view target) -> xls_dslx_function* {
auto find_function_member =
[&](std::string_view target) -> xls_dslx_module_member* {
int64_t member_count = xls_dslx_module_get_member_count(module);
for (int64_t i = 0; i < member_count; ++i) {
xls_dslx_module_member* member = xls_dslx_module_get_member(module, i);
Expand All @@ -1726,20 +1739,24 @@ fn main(x: u32) -> u32 {
char* identifier = xls_dslx_function_get_identifier(fn);
absl::Cleanup free_identifier([&] { xls_c_str_free(identifier); });
if (std::string_view{identifier} == target) {
return fn;
return member;
}
}
return nullptr;
};

xls_dslx_function* unused_fn = find_function("unused");
xls_dslx_module_member* unused_member = find_function_member("unused");
ASSERT_NE(unused_member, nullptr);
xls_dslx_function* unused_fn =
xls_dslx_module_member_get_function(unused_member);
ASSERT_NE(unused_fn, nullptr);
EXPECT_EQ(xls_dslx_module_member_from_function(unused_fn), unused_member);

xls_dslx_function* removed[] = {unused_fn};
xls_dslx_module_member* removed[] = {unused_member};
xls_dslx_typechecked_module* cloned_tm = nullptr;
ASSERT_TRUE(xls_dslx_typechecked_module_clone_removing_functions(
tm, removed, ABSL_ARRAYSIZE(removed), "top_clone", import_data, &error,
&cloned_tm));
ASSERT_TRUE(xls_dslx_typechecked_module_clone_removing_members(
tm, removed, ABSL_ARRAYSIZE(removed), "top_clone_members", import_data,
&error, &cloned_tm));
ASSERT_EQ(error, nullptr);
absl::Cleanup free_cloned_tm(
[=] { xls_dslx_typechecked_module_free(cloned_tm); });
Expand All @@ -1752,12 +1769,41 @@ fn main(x: u32) -> u32 {
xls_dslx_function* helper_fn =
xls_dslx_module_member_get_function(first_member);
ASSERT_NE(helper_fn, nullptr);
EXPECT_EQ(xls_dslx_module_member_from_function(helper_fn), first_member);
char* helper_name = xls_dslx_function_get_identifier(helper_fn);
absl::Cleanup free_helper_name([&] { xls_c_str_free(helper_name); });
EXPECT_EQ(std::string_view{helper_name}, "helper");
char* module_name = xls_dslx_module_get_name(cloned_module);
absl::Cleanup free_module_name([&] { xls_c_str_free(module_name); });
EXPECT_EQ(std::string_view{module_name}, "top_clone");
EXPECT_EQ(std::string_view{module_name}, "top_clone_members");

xls_dslx_function* removed_functions[] = {unused_fn};
xls_dslx_typechecked_module* cloned_tm_functions = nullptr;
ASSERT_TRUE(xls_dslx_typechecked_module_clone_removing_functions(
tm, removed_functions, ABSL_ARRAYSIZE(removed_functions),
"top_clone_functions", import_data, &error, &cloned_tm_functions));
ASSERT_EQ(error, nullptr);
absl::Cleanup free_cloned_tm_functions(
[=] { xls_dslx_typechecked_module_free(cloned_tm_functions); });

xls_dslx_module* cloned_module_functions =
xls_dslx_typechecked_module_get_module(cloned_tm_functions);
EXPECT_EQ(xls_dslx_module_get_member_count(cloned_module_functions), 2);
xls_dslx_module_member* first_member_functions =
xls_dslx_module_get_member(cloned_module_functions, 0);
xls_dslx_function* helper_fn_functions =
xls_dslx_module_member_get_function(first_member_functions);
ASSERT_NE(helper_fn_functions, nullptr);
char* helper_name_functions =
xls_dslx_function_get_identifier(helper_fn_functions);
absl::Cleanup free_helper_name_functions(
[&] { xls_c_str_free(helper_name_functions); });
EXPECT_EQ(std::string_view{helper_name_functions}, "helper");
char* module_name_functions =
xls_dslx_module_get_name(cloned_module_functions);
absl::Cleanup free_module_name_functions(
[&] { xls_c_str_free(module_name_functions); });
EXPECT_EQ(std::string_view{module_name_functions}, "top_clone_functions");
}

TEST(XlsCApiTest, DslxCloneTypecheckedModuleRemovingMembersFailure) {
Expand Down Expand Up @@ -3238,6 +3284,7 @@ fn prop(x: u8) -> bool {
// Retrieve the QuickCheck node.
xls_dslx_quickcheck* qc = xls_dslx_module_member_get_quickcheck(member);
ASSERT_NE(qc, nullptr);
EXPECT_EQ(xls_dslx_module_member_from_quickcheck(qc), member);

// Inspect the associated function.
xls_dslx_function* fn = xls_dslx_quickcheck_get_function(qc);
Expand Down