Skip to content

Commit

Permalink
Add channel matcher for proc-scoped channels.
Browse files Browse the repository at this point in the history
This required removing an unused id matcher as channel references don't have ids. One kind of matching no longer works because C++: `send(m::Add(), channel)` where `channel` is `Channel*`.

PiperOrigin-RevId: 686295132
  • Loading branch information
meheffernan authored and copybara-github committed Oct 16, 2024
1 parent 6e52a4e commit 33fe906
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 116 deletions.
7 changes: 7 additions & 0 deletions xls/ir/channel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -304,4 +304,11 @@ std::string ChannelReference::ToString() const {
absl::StrJoin(keyword_strs, " "));
}

std::string ChannelRefToString(ChannelRef ref) {
if (std::holds_alternative<ChannelReference*>(ref)) {
return std::get<ChannelReference*>(ref)->ToString();
}
return std::get<Channel*>(ref)->ToString();
}

} // namespace xls
1 change: 1 addition & 0 deletions xls/ir/channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,7 @@ std::string_view ChannelRefName(ChannelRef ref);
Type* ChannelRefType(ChannelRef ref);
ChannelKind ChannelRefKind(ChannelRef ref);
std::optional<ChannelStrictness> ChannelRefStrictness(ChannelRef ref);
std::string ChannelRefToString(ChannelRef ref);

} // namespace xls

Expand Down
56 changes: 29 additions & 27 deletions xls/ir/ir_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#include "xls/ir/node.h"
#include "xls/ir/nodes.h"
#include "xls/ir/op.h"
#include "xls/ir/proc.h"
#include "xls/ir/type.h"
#include "xls/ir/value.h"

Expand All @@ -65,33 +66,25 @@ void NameMatcherInternal::DescribeNegationTo(std::ostream* os) const {
} // namespace internal

bool ChannelMatcher::MatchAndExplain(
const ::xls::Channel* channel,
::testing::MatchResultListener* listener) const {
if (channel == nullptr) {
return false;
}
*listener << channel->ToString();
if (id_.has_value() && channel->id() != id_.value()) {
*listener << absl::StreamFormat(" has incorrect id (%d), expected: %d",
channel->id(), id_.value());
return false;
}
::xls::ChannelRef channel, ::testing::MatchResultListener* listener) const {
*listener << ChannelRefToString(channel);

if (name_.has_value() &&
!name_->MatchAndExplain(std::string{channel->name()}, listener)) {
!name_->MatchAndExplain(std::string{ChannelRefName(channel)}, listener)) {
return false;
}

if (kind_.has_value() && channel->kind() != kind_.value()) {
*listener << absl::StreamFormat(" has incorrect kind (%s), expected: %s",
ChannelKindToString(channel->kind()),
ChannelKindToString(kind_.value()));
if (kind_.has_value() && ChannelRefKind(channel) != kind_.value()) {
*listener << absl::StreamFormat(
" has incorrect kind (%s), expected: %s",
ChannelKindToString(ChannelRefKind(channel)),
ChannelKindToString(kind_.value()));
return false;
}
if (type_string_.has_value() &&
channel->type()->ToString() != type_string_.value()) {
ChannelRefType(channel)->ToString() != type_string_.value()) {
*listener << absl::StreamFormat(" has incorrect type (%s), expected: %s",
channel->type()->ToString(),
ChannelRefType(channel)->ToString(),
type_string_.value());
return false;
}
Expand All @@ -100,9 +93,6 @@ bool ChannelMatcher::MatchAndExplain(

void ChannelMatcher::DescribeTo(::std::ostream* os) const {
std::vector<std::string> pieces;
if (id_.has_value()) {
pieces.push_back(absl::StrFormat("id=%d", id_.value()));
}
if (name_.has_value()) {
std::stringstream ss;
ss << "name=\"";
Expand Down Expand Up @@ -382,16 +372,17 @@ bool TupleIndexMatcher::MatchAndExplain(
}

static bool MatchChannel(
std::string_view channel, Package* package,
const ::testing::Matcher<const ::xls::Channel*>& channel_matcher,
std::string_view channel, ::xls::Proc* proc, ::xls::Direction direction,
const ::testing::Matcher<::xls::ChannelRef>& channel_matcher,
::testing::MatchResultListener* listener) {
absl::StatusOr<::xls::Channel*> channel_status = package->GetChannel(channel);
absl::StatusOr<::xls::ChannelRef> channel_status =
proc->GetChannelRef(channel, direction);
if (!channel_status.ok()) {
*listener << " has an invalid channel name: " << channel;
return false;
}
::xls::Channel* ch = channel_status.value();
return channel_matcher.MatchAndExplain(ch, listener);
::xls::ChannelRef ch_ref = channel_status.value();
return channel_matcher.MatchAndExplain(ch_ref, listener);
}

static std::string_view GetChannelName(const Node* node) {
Expand All @@ -413,7 +404,18 @@ bool ChannelNodeMatcher::MatchAndExplain(
if (!channel_matcher_.has_value()) {
return true;
}
return MatchChannel(GetChannelName(node), node->package(),
Direction direction;
if (node->Is<::xls::Send>()) {
direction = Direction::kSend;
} else if (node->Is<::xls::Receive>()) {
direction = Direction::kReceive;
} else {
LOG(FATAL) << absl::StrFormat(
"Expected send or receive node, got node `%s` with op `%s`",
node->GetName(), OpToString(node->op()));
}
return MatchChannel(GetChannelName(node),
node->function_base()->AsProcOrDie(), direction,
channel_matcher_.value(), listener);
}

Expand Down
95 changes: 40 additions & 55 deletions xls/ir/ir_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -653,97 +653,82 @@ inline ::testing::Matcher<const ::xls::Node*> TupleIndex(
// which communicate over channels (e.g., send and receive). Supported forms:
//
// m::Channel(/*name=*/"foo");
// m::Channel(/*id=*/42);
// m::Channel(ChannelKind::kPort);
// m::Channel(ChannelKind::kStreaming);
// m::Channel(node->GetType());
// m::ChannelWithType("bits[32]");
//
class ChannelMatcher
: public ::testing::MatcherInterface<const ::xls::Channel*> {
class ChannelMatcher : public ::testing::MatcherInterface<::xls::ChannelRef> {
public:
ChannelMatcher(std::optional<int64_t> id,
std::optional<::testing::Matcher<std::string>> name,
ChannelMatcher(std::optional<::testing::Matcher<std::string>> name,
std::optional<ChannelKind> kind,
std::optional<std::string_view> type_string)
: id_(id),
name_(std::move(name)),
kind_(kind),
type_string_(type_string) {}
: name_(std::move(name)), kind_(kind), type_string_(type_string) {}

bool MatchAndExplain(const ::xls::Channel* channel,
bool MatchAndExplain(::xls::ChannelRef channel,
::testing::MatchResultListener* listener) const override;

void DescribeTo(::std::ostream* os) const override;

protected:
std::optional<int64_t> id_;
std::optional<::testing::Matcher<std::string>> name_;
std::optional<ChannelKind> kind_;
std::optional<std::string> type_string_;
};

inline ::testing::Matcher<const ::xls::Channel*> Channel() {
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
std::nullopt, std::nullopt, std::nullopt, std::nullopt));
}

inline ::testing::Matcher<const ::xls::Channel*> Channel(
std::optional<int64_t> id) {
inline ::testing::Matcher<::xls::ChannelRef> Channel() {
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
id, std::nullopt, std::nullopt, std::nullopt));
std::nullopt, std::nullopt, std::nullopt));
}

template <typename T>
inline ::testing::Matcher<const ::xls::Channel*> Channel(
std::optional<int64_t> id, T name,
std::optional<ChannelKind> kind = std::nullopt,
std::optional<const ::xls::Type*> type_ = std::nullopt)
inline ::testing::Matcher<::xls::ChannelRef> Channel(
T name, std::optional<ChannelKind> kind,
std::optional<const ::xls::Type*> type_)
requires(std::is_convertible_v<T, std::string>)
{
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
id, std::string{name}, kind,
std::string{name}, kind,
type_.has_value() ? std::optional(type_.value()->ToString())
: std::nullopt));
}

template <typename T>
inline ::testing::Matcher<const ::xls::Channel*> Channel(T name)
inline ::testing::Matcher<::xls::ChannelRef> Channel(T name)
requires(std::is_convertible_v<T, std::string_view>)
{
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
std::nullopt, internal::NameMatcherInternal(std::string_view{name}),
std::nullopt, std::nullopt));
internal::NameMatcherInternal(std::string_view{name}), std::nullopt,
std::nullopt));
}

inline ::testing::Matcher<const ::xls::Channel*> Channel(
inline ::testing::Matcher<::xls::ChannelRef> Channel(
const ::testing::Matcher<std::string>& matcher) {
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
std::nullopt, matcher, std::nullopt, std::nullopt));
matcher, std::nullopt, std::nullopt));
}

inline ::testing::Matcher<const ::xls::Channel*> Channel(ChannelKind kind) {
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
std::nullopt, std::nullopt, kind, std::nullopt));
inline ::testing::Matcher<::xls::ChannelRef> Channel(ChannelKind kind) {
return ::testing::MakeMatcher(
new ::xls::op_matchers::ChannelMatcher(std::nullopt, kind, std::nullopt));
}

inline ::testing::Matcher<const ::xls::Channel*> Channel(
const ::xls::Type* type_) {
inline ::testing::Matcher<::xls::ChannelRef> Channel(const ::xls::Type* type_) {
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
std::nullopt, std::nullopt, std::nullopt, type_->ToString()));
std::nullopt, std::nullopt, type_->ToString()));
}

inline ::testing::Matcher<const ::xls::Channel*> ChannelWithType(
inline ::testing::Matcher<::xls::ChannelRef> ChannelWithType(
std::string_view type_string) {
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
std::nullopt, std::nullopt, std::nullopt, type_string));
std::nullopt, std::nullopt, type_string));
}

// Abstract base class for matchers of nodes which use channels.
class ChannelNodeMatcher : public NodeMatcher {
public:
ChannelNodeMatcher(
Op op, absl::Span<const ::testing::Matcher<const Node*>> operands,
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher)
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher)
: NodeMatcher(op, operands),
channel_matcher_(std::move(channel_matcher)) {}

Expand All @@ -752,50 +737,50 @@ class ChannelNodeMatcher : public NodeMatcher {
void DescribeTo(::std::ostream* os) const override;

private:
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher_;
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher_;
};

// Send matcher. Supported forms:
//
// EXPECT_THAT(foo, m::Send());
// EXPECT_THAT(foo, m::Send(m::Channel(42)));
// EXPECT_THAT(foo, m::Send(m::Channel("foo")));
// EXPECT_THAT(foo, m::Send(/*token=*/m::Param(), /*data=*/m::Param(),
// m::Channel(42)));
// m::Channel("bar")));
// EXPECT_THAT(foo, m::Send(/*token=*/m::Param(), /*data=*/m::Param(),
// /*predicate=*/m::Param(),
// m::Channel(42)));
// m::Channel("bazz")));
class SendMatcher : public ChannelNodeMatcher {
public:
explicit SendMatcher(
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher)
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher)
: ChannelNodeMatcher(Op::kSend, {}, std::move(channel_matcher)) {}
explicit SendMatcher(
::testing::Matcher<const Node*> token,
::testing::Matcher<const Node*> data,
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher)
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher)
: ChannelNodeMatcher(Op::kSend, {std::move(token), std::move(data)},
std::move(channel_matcher)) {}
explicit SendMatcher(
::testing::Matcher<const Node*> token,
::testing::Matcher<const Node*> data,
::testing::Matcher<const Node*> predicate,
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher)
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher)
: ChannelNodeMatcher(
Op::kSend,
{std::move(token), std::move(data), std::move(predicate)},
std::move(channel_matcher)) {}
};

inline ::testing::Matcher<const ::xls::Node*> Send(
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher =
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher =
std::nullopt) {
return ::xls::op_matchers::SendMatcher(std::move(channel_matcher));
}

inline ::testing::Matcher<const ::xls::Node*> Send(
::testing::Matcher<const ::xls::Node*> token,
::testing::Matcher<const Node*> data,
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher =
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher =
std::nullopt) {
return ::xls::op_matchers::SendMatcher(std::move(token), std::move(data),
std::move(channel_matcher));
Expand All @@ -805,7 +790,7 @@ inline ::testing::Matcher<const ::xls::Node*> Send(
::testing::Matcher<const ::xls::Node*> token,
::testing::Matcher<const Node*> data,
::testing::Matcher<const Node*> predicate,
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher =
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher =
std::nullopt) {
return ::xls::op_matchers::SendMatcher(std::move(token), std::move(data),
std::move(predicate),
Expand All @@ -822,31 +807,31 @@ inline ::testing::Matcher<const ::xls::Node*> Send(
class ReceiveMatcher : public ChannelNodeMatcher {
public:
explicit ReceiveMatcher(
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher)
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher)
: ChannelNodeMatcher(Op::kReceive, {}, std::move(channel_matcher)) {}
explicit ReceiveMatcher(
::testing::Matcher<const Node*> token,
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher)
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher)
: ChannelNodeMatcher(Op::kReceive, {std::move(token)},
std::move(channel_matcher)) {}
explicit ReceiveMatcher(
::testing::Matcher<const Node*> token,
::testing::Matcher<const Node*> predicate,
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher)
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher)
: ChannelNodeMatcher(Op::kReceive,
{std::move(token), std::move(predicate)},
std::move(channel_matcher)) {}
};

inline ::testing::Matcher<const ::xls::Node*> Receive(
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher =
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher =
std::nullopt) {
return ::xls::op_matchers::ReceiveMatcher(std::move(channel_matcher));
}

inline ::testing::Matcher<const ::xls::Node*> Receive(
::testing::Matcher<const Node*> token,
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher =
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher =
std::nullopt) {
return ::xls::op_matchers::ReceiveMatcher(std::move(token),
std::move(channel_matcher));
Expand All @@ -855,7 +840,7 @@ inline ::testing::Matcher<const ::xls::Node*> Receive(
inline ::testing::Matcher<const ::xls::Node*> Receive(
::testing::Matcher<const Node*> token,
::testing::Matcher<const Node*> predicate,
::testing::Matcher<const ::xls::Channel*> channel_matcher) {
::testing::Matcher<::xls::ChannelRef> channel_matcher) {
return ::xls::op_matchers::ReceiveMatcher(
std::move(token), std::move(predicate), channel_matcher);
}
Expand Down
Loading

0 comments on commit 33fe906

Please sign in to comment.