Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions xls/dslx/fmt/ast_fmt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3167,7 +3167,10 @@ absl::StatusOr<std::string> AutoFmt(VirtualizableFilesystem& vfs,
FormatDisabler disabler(vfs, comments, *m.fs_path());
XLS_ASSIGN_OR_RETURN(
std::unique_ptr<Module> clone,
CloneModule(m, std::bind_front(&FormatDisabler::operator(), &disabler)));
CloneModule(m, [&](const AstNode* node, Module*,
const absl::flat_hash_map<const AstNode*, AstNode*>&) {
return disabler(node);
}));
return AutoFmt(*clone, comments, text_width);
}

Expand All @@ -3177,7 +3180,10 @@ absl::StatusOr<std::string> AutoFmt(VirtualizableFilesystem& vfs,
FormatDisabler disabler(vfs, comments, contents);
XLS_ASSIGN_OR_RETURN(
std::unique_ptr<Module> clone,
CloneModule(m, std::bind_front(&FormatDisabler::operator(), &disabler)));
CloneModule(m, [&](const AstNode* node, Module*,
const absl::flat_hash_map<const AstNode*, AstNode*>&) {
return disabler(node);
}));
return AutoFmt(*clone, comments, text_width);
}

Expand Down
65 changes: 41 additions & 24 deletions xls/dslx/frontend/ast_cloner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -949,16 +949,19 @@ class AstCloner : public AstNodeVisitor {
XLS_RETURN_IF_ERROR(VisitChildren(n));

XLS_RETURN_IF_ERROR(ReplaceOrVisit(&n->fn()));
old_to_new_[n] = module(n)->Make<TestFunction>(
n->span(), *down_cast<Function*>(old_to_new_.at(&n->fn())));
XLS_ASSIGN_OR_RETURN(Function * new_fn, CastIfNotVerbatim<Function*>(
old_to_new_.at(&n->fn())));
old_to_new_[n] = module(n)->Make<TestFunction>(n->span(), *new_fn);
return absl::OkStatus();
}

absl::Status HandleTestProc(const TestProc* n) override {
XLS_RETURN_IF_ERROR(VisitChildren(n));

old_to_new_[n] = module(n)->Make<TestProc>(
down_cast<Proc*>(old_to_new_.at(n->proc())), n->expected_fail_label());
XLS_ASSIGN_OR_RETURN(Proc * new_proc,
CastIfNotVerbatim<Proc*>(old_to_new_.at(n->proc())));
old_to_new_[n] =
module(n)->Make<TestProc>(new_proc, n->expected_fail_label());
return absl::OkStatus();
}

Expand Down Expand Up @@ -1190,9 +1193,7 @@ class AstCloner : public AstNodeVisitor {
// already been processed.
absl::Status VisitChildren(const AstNode* node) {
for (const auto& child : node->GetChildren(/*want_types=*/true)) {
if (!old_to_new_.contains(child)) {
XLS_RETURN_IF_ERROR(ReplaceOrVisit(child));
}
XLS_RETURN_IF_ERROR(ReplaceOrVisit(child));
}
return absl::OkStatus();
}
Expand All @@ -1201,7 +1202,11 @@ class AstCloner : public AstNodeVisitor {
if (node == nullptr) {
return absl::OkStatus();
}
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> replacement, replacer_(node));
if (old_to_new_.contains(node)) {
return absl::OkStatus();
}
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> replacement,
replacer_(node, module(node), old_to_new_));
if (replacement.has_value()) {
old_to_new_[node] = *replacement;
return absl::OkStatus();
Expand Down Expand Up @@ -1234,17 +1239,20 @@ class AstCloner : public AstNodeVisitor {

} // namespace

std::optional<AstNode*> PreserveTypeDefinitionsReplacer(const AstNode* node) {
std::optional<AstNode*> PreserveTypeDefinitionsReplacer(
const AstNode* node, Module* module,
const absl::flat_hash_map<const AstNode*, AstNode*>&) {
if (node->kind() == AstNodeKind::kTypeRef) {
const auto* type_ref = down_cast<const TypeRef*>(node);
return node->owner()->Make<TypeRef>(type_ref->span(),
type_ref->type_definition());
return module->Make<TypeRef>(type_ref->span(), type_ref->type_definition());
}
return std::nullopt;
}

CloneReplacer NameRefReplacer(const NameDef* def, Expr* replacement) {
return [=](const AstNode* node) -> std::optional<AstNode*> {
return [=](const AstNode* node, Module* new_module,
const absl::flat_hash_map<const AstNode*, AstNode*>&)
-> std::optional<AstNode*> {
if (node->kind() == AstNodeKind::kNameRef) {
const auto* name_ref = down_cast<const NameRef*>(node);
if (std::holds_alternative<const NameDef*>(name_ref->name_def()) &&
Expand All @@ -1258,14 +1266,16 @@ CloneReplacer NameRefReplacer(const NameDef* def, Expr* replacement) {

CloneReplacer NameRefReplacer(
const absl::flat_hash_map<const NameDef*, NameDef*>* replacement_defs) {
return [=](const AstNode* original_node) -> std::optional<AstNode*> {
return [=](const AstNode* original_node, Module* new_module,
const absl::flat_hash_map<const AstNode*, AstNode*>&)
-> std::optional<AstNode*> {
if (original_node->kind() == AstNodeKind::kNameRef) {
const auto* original_ref = down_cast<const NameRef*>(original_node);
const AstNode* def = ToAstNode(original_ref->name_def());
if (def->kind() == AstNodeKind::kNameDef) {
const auto it = replacement_defs->find(down_cast<const NameDef*>(def));
if (it != replacement_defs->end()) {
return original_node->owner()->Make<NameRef>(
return new_module->Make<NameRef>(
original_ref->span(), original_ref->identifier(), it->second);
}
}
Expand All @@ -1281,8 +1291,11 @@ CloneAstAndGetAllPairs(const AstNode* root,
if (root->kind() == AstNodeKind::kModule) {
return absl::InvalidArgumentError("Clone a module via 'CloneModule'.");
}
Module* new_module =
target_module.has_value() ? *target_module : root->owner();
absl::flat_hash_map<const AstNode*, AstNode*> empty_old_to_new;
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> root_replacement,
replacer(root));
replacer(root, new_module, empty_old_to_new));
if (root_replacement.has_value()) {
return absl::flat_hash_map<const AstNode*, AstNode*>{
{root, *root_replacement}};
Expand Down Expand Up @@ -1313,15 +1326,19 @@ absl::StatusOr<std::unique_ptr<Module>> CloneModule(const Module& module,
}

CloneReplacer ChainCloneReplacers(CloneReplacer first, CloneReplacer second) {
return [first = std::move(first),
second = std::move(second)](const AstNode* node) mutable
-> absl::StatusOr<std::optional<AstNode*>> {
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> first_result, first(node));
XLS_ASSIGN_OR_RETURN(
std::optional<AstNode*> second_result,
second(first_result.has_value() ? *first_result : node));
return second_result.has_value() ? second_result : first_result;
};
return
[first = std::move(first), second = std::move(second)](
const AstNode* node, Module* module,
const absl::flat_hash_map<const AstNode*, AstNode*>&
old_to_new) mutable -> absl::StatusOr<std::optional<AstNode*>> {
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> first_result,
first(node, module, old_to_new));
XLS_ASSIGN_OR_RETURN(
std::optional<AstNode*> second_result,
second(first_result.has_value() ? *first_result : node, module,
old_to_new));
return second_result.has_value() ? second_result : first_result;
};
}

// Verifies that `node` consists solely of "new" AST nodes and none that are
Expand Down
23 changes: 18 additions & 5 deletions xls/dslx/frontend/ast_cloner.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,21 @@ namespace xls::dslx {
// nodes during a `CloneAst` operations. A replacer can be used to replace
// targeted nodes with something else entirely, or it can just "clone" those
// nodes differently than the default logic.
//
// The replacer is invoked with:
// - the original AST node under consideration
// - the target `Module*` where any new nodes should be created
// - a pointer to the current old->new mapping accumulated so far during clone
using CloneReplacer =
absl::AnyInvocable<absl::StatusOr<std::optional<AstNode*>>(const AstNode*)>;
absl::AnyInvocable<absl::StatusOr<std::optional<AstNode*>>(
const AstNode*, Module*,
const absl::flat_hash_map<const AstNode*, AstNode*>&)>;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a replacer somewhere that actually wants to use the old_to_new argument?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am going to implement one using this. Basically, when we want to create a new node with some name ref, we want the reference points to something that is in the new module instead of the old one. I would love to use this mapping to resolve it.


// This function is directly usable as the `replacer` argument for `CloneAst`
// when a direct clone with no replacements is desired.
inline std::optional<AstNode*> NoopCloneReplacer(const AstNode* original_node) {
inline std::optional<AstNode*> NoopCloneReplacer(
const AstNode* original_node, Module*,
const absl::flat_hash_map<const AstNode*, AstNode*>&) {
return std::nullopt;
}

Expand All @@ -50,8 +59,11 @@ class ObservableCloneReplacer {
explicit ObservableCloneReplacer(bool* flag, CloneReplacer replacer)
: flag_(flag), replacer_(std::move(replacer)) {}

absl::StatusOr<std::optional<AstNode*>> operator()(const AstNode* node) {
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> result, replacer_(node));
absl::StatusOr<std::optional<AstNode*>> operator()(
const AstNode* node, Module* module,
const absl::flat_hash_map<const AstNode*, AstNode*>& old_to_new) {
XLS_ASSIGN_OR_RETURN(std::optional<AstNode*> result,
replacer_(node, module, old_to_new));
*flag_ |= result.has_value();
return result;
}
Expand All @@ -66,7 +78,8 @@ class ObservableCloneReplacer {
// cloning return types without recursing into cloned definitions which would
// change nominal types.
std::optional<AstNode*> PreserveTypeDefinitionsReplacer(
const AstNode* original_node);
const AstNode* original_node, Module* module,
const absl::flat_hash_map<const AstNode*, AstNode*>& old_to_new);

// Creates a `CloneReplacer` that replaces references to the given `def` with
// the given `replacement`.
Expand Down
Loading