Skip to content

Commit 33fe906

Browse files
meheffernancopybara-github
authored andcommitted
Add channel matcher for proc-scoped channels.
This required removing an unused id matcher as channel references don't have ids. One kind of matching no longer works because C++: `send(m::Add(), channel)` where `channel` is `Channel*`. PiperOrigin-RevId: 686295132
1 parent 6e52a4e commit 33fe906

File tree

9 files changed

+142
-116
lines changed

9 files changed

+142
-116
lines changed

xls/ir/channel.cc

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -304,4 +304,11 @@ std::string ChannelReference::ToString() const {
304304
absl::StrJoin(keyword_strs, " "));
305305
}
306306

307+
std::string ChannelRefToString(ChannelRef ref) {
308+
if (std::holds_alternative<ChannelReference*>(ref)) {
309+
return std::get<ChannelReference*>(ref)->ToString();
310+
}
311+
return std::get<Channel*>(ref)->ToString();
312+
}
313+
307314
} // namespace xls

xls/ir/channel.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,7 @@ std::string_view ChannelRefName(ChannelRef ref);
449449
Type* ChannelRefType(ChannelRef ref);
450450
ChannelKind ChannelRefKind(ChannelRef ref);
451451
std::optional<ChannelStrictness> ChannelRefStrictness(ChannelRef ref);
452+
std::string ChannelRefToString(ChannelRef ref);
452453

453454
} // namespace xls
454455

xls/ir/ir_matcher.cc

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include "xls/ir/node.h"
4141
#include "xls/ir/nodes.h"
4242
#include "xls/ir/op.h"
43+
#include "xls/ir/proc.h"
4344
#include "xls/ir/type.h"
4445
#include "xls/ir/value.h"
4546

@@ -65,33 +66,25 @@ void NameMatcherInternal::DescribeNegationTo(std::ostream* os) const {
6566
} // namespace internal
6667

6768
bool ChannelMatcher::MatchAndExplain(
68-
const ::xls::Channel* channel,
69-
::testing::MatchResultListener* listener) const {
70-
if (channel == nullptr) {
71-
return false;
72-
}
73-
*listener << channel->ToString();
74-
if (id_.has_value() && channel->id() != id_.value()) {
75-
*listener << absl::StreamFormat(" has incorrect id (%d), expected: %d",
76-
channel->id(), id_.value());
77-
return false;
78-
}
69+
::xls::ChannelRef channel, ::testing::MatchResultListener* listener) const {
70+
*listener << ChannelRefToString(channel);
7971

8072
if (name_.has_value() &&
81-
!name_->MatchAndExplain(std::string{channel->name()}, listener)) {
73+
!name_->MatchAndExplain(std::string{ChannelRefName(channel)}, listener)) {
8274
return false;
8375
}
8476

85-
if (kind_.has_value() && channel->kind() != kind_.value()) {
86-
*listener << absl::StreamFormat(" has incorrect kind (%s), expected: %s",
87-
ChannelKindToString(channel->kind()),
88-
ChannelKindToString(kind_.value()));
77+
if (kind_.has_value() && ChannelRefKind(channel) != kind_.value()) {
78+
*listener << absl::StreamFormat(
79+
" has incorrect kind (%s), expected: %s",
80+
ChannelKindToString(ChannelRefKind(channel)),
81+
ChannelKindToString(kind_.value()));
8982
return false;
9083
}
9184
if (type_string_.has_value() &&
92-
channel->type()->ToString() != type_string_.value()) {
85+
ChannelRefType(channel)->ToString() != type_string_.value()) {
9386
*listener << absl::StreamFormat(" has incorrect type (%s), expected: %s",
94-
channel->type()->ToString(),
87+
ChannelRefType(channel)->ToString(),
9588
type_string_.value());
9689
return false;
9790
}
@@ -100,9 +93,6 @@ bool ChannelMatcher::MatchAndExplain(
10093

10194
void ChannelMatcher::DescribeTo(::std::ostream* os) const {
10295
std::vector<std::string> pieces;
103-
if (id_.has_value()) {
104-
pieces.push_back(absl::StrFormat("id=%d", id_.value()));
105-
}
10696
if (name_.has_value()) {
10797
std::stringstream ss;
10898
ss << "name=\"";
@@ -382,16 +372,17 @@ bool TupleIndexMatcher::MatchAndExplain(
382372
}
383373

384374
static bool MatchChannel(
385-
std::string_view channel, Package* package,
386-
const ::testing::Matcher<const ::xls::Channel*>& channel_matcher,
375+
std::string_view channel, ::xls::Proc* proc, ::xls::Direction direction,
376+
const ::testing::Matcher<::xls::ChannelRef>& channel_matcher,
387377
::testing::MatchResultListener* listener) {
388-
absl::StatusOr<::xls::Channel*> channel_status = package->GetChannel(channel);
378+
absl::StatusOr<::xls::ChannelRef> channel_status =
379+
proc->GetChannelRef(channel, direction);
389380
if (!channel_status.ok()) {
390381
*listener << " has an invalid channel name: " << channel;
391382
return false;
392383
}
393-
::xls::Channel* ch = channel_status.value();
394-
return channel_matcher.MatchAndExplain(ch, listener);
384+
::xls::ChannelRef ch_ref = channel_status.value();
385+
return channel_matcher.MatchAndExplain(ch_ref, listener);
395386
}
396387

397388
static std::string_view GetChannelName(const Node* node) {
@@ -413,7 +404,18 @@ bool ChannelNodeMatcher::MatchAndExplain(
413404
if (!channel_matcher_.has_value()) {
414405
return true;
415406
}
416-
return MatchChannel(GetChannelName(node), node->package(),
407+
Direction direction;
408+
if (node->Is<::xls::Send>()) {
409+
direction = Direction::kSend;
410+
} else if (node->Is<::xls::Receive>()) {
411+
direction = Direction::kReceive;
412+
} else {
413+
LOG(FATAL) << absl::StrFormat(
414+
"Expected send or receive node, got node `%s` with op `%s`",
415+
node->GetName(), OpToString(node->op()));
416+
}
417+
return MatchChannel(GetChannelName(node),
418+
node->function_base()->AsProcOrDie(), direction,
417419
channel_matcher_.value(), listener);
418420
}
419421

xls/ir/ir_matcher.h

Lines changed: 40 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -653,97 +653,82 @@ inline ::testing::Matcher<const ::xls::Node*> TupleIndex(
653653
// which communicate over channels (e.g., send and receive). Supported forms:
654654
//
655655
// m::Channel(/*name=*/"foo");
656-
// m::Channel(/*id=*/42);
657-
// m::Channel(ChannelKind::kPort);
656+
// m::Channel(ChannelKind::kStreaming);
658657
// m::Channel(node->GetType());
659658
// m::ChannelWithType("bits[32]");
660659
//
661-
class ChannelMatcher
662-
: public ::testing::MatcherInterface<const ::xls::Channel*> {
660+
class ChannelMatcher : public ::testing::MatcherInterface<::xls::ChannelRef> {
663661
public:
664-
ChannelMatcher(std::optional<int64_t> id,
665-
std::optional<::testing::Matcher<std::string>> name,
662+
ChannelMatcher(std::optional<::testing::Matcher<std::string>> name,
666663
std::optional<ChannelKind> kind,
667664
std::optional<std::string_view> type_string)
668-
: id_(id),
669-
name_(std::move(name)),
670-
kind_(kind),
671-
type_string_(type_string) {}
665+
: name_(std::move(name)), kind_(kind), type_string_(type_string) {}
672666

673-
bool MatchAndExplain(const ::xls::Channel* channel,
667+
bool MatchAndExplain(::xls::ChannelRef channel,
674668
::testing::MatchResultListener* listener) const override;
675669

676670
void DescribeTo(::std::ostream* os) const override;
677671

678672
protected:
679-
std::optional<int64_t> id_;
680673
std::optional<::testing::Matcher<std::string>> name_;
681674
std::optional<ChannelKind> kind_;
682675
std::optional<std::string> type_string_;
683676
};
684677

685-
inline ::testing::Matcher<const ::xls::Channel*> Channel() {
686-
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
687-
std::nullopt, std::nullopt, std::nullopt, std::nullopt));
688-
}
689-
690-
inline ::testing::Matcher<const ::xls::Channel*> Channel(
691-
std::optional<int64_t> id) {
678+
inline ::testing::Matcher<::xls::ChannelRef> Channel() {
692679
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
693-
id, std::nullopt, std::nullopt, std::nullopt));
680+
std::nullopt, std::nullopt, std::nullopt));
694681
}
695682

696683
template <typename T>
697-
inline ::testing::Matcher<const ::xls::Channel*> Channel(
698-
std::optional<int64_t> id, T name,
699-
std::optional<ChannelKind> kind = std::nullopt,
700-
std::optional<const ::xls::Type*> type_ = std::nullopt)
684+
inline ::testing::Matcher<::xls::ChannelRef> Channel(
685+
T name, std::optional<ChannelKind> kind,
686+
std::optional<const ::xls::Type*> type_)
701687
requires(std::is_convertible_v<T, std::string>)
702688
{
703689
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
704-
id, std::string{name}, kind,
690+
std::string{name}, kind,
705691
type_.has_value() ? std::optional(type_.value()->ToString())
706692
: std::nullopt));
707693
}
708694

709695
template <typename T>
710-
inline ::testing::Matcher<const ::xls::Channel*> Channel(T name)
696+
inline ::testing::Matcher<::xls::ChannelRef> Channel(T name)
711697
requires(std::is_convertible_v<T, std::string_view>)
712698
{
713699
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
714-
std::nullopt, internal::NameMatcherInternal(std::string_view{name}),
715-
std::nullopt, std::nullopt));
700+
internal::NameMatcherInternal(std::string_view{name}), std::nullopt,
701+
std::nullopt));
716702
}
717703

718-
inline ::testing::Matcher<const ::xls::Channel*> Channel(
704+
inline ::testing::Matcher<::xls::ChannelRef> Channel(
719705
const ::testing::Matcher<std::string>& matcher) {
720706
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
721-
std::nullopt, matcher, std::nullopt, std::nullopt));
707+
matcher, std::nullopt, std::nullopt));
722708
}
723709

724-
inline ::testing::Matcher<const ::xls::Channel*> Channel(ChannelKind kind) {
725-
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
726-
std::nullopt, std::nullopt, kind, std::nullopt));
710+
inline ::testing::Matcher<::xls::ChannelRef> Channel(ChannelKind kind) {
711+
return ::testing::MakeMatcher(
712+
new ::xls::op_matchers::ChannelMatcher(std::nullopt, kind, std::nullopt));
727713
}
728714

729-
inline ::testing::Matcher<const ::xls::Channel*> Channel(
730-
const ::xls::Type* type_) {
715+
inline ::testing::Matcher<::xls::ChannelRef> Channel(const ::xls::Type* type_) {
731716
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
732-
std::nullopt, std::nullopt, std::nullopt, type_->ToString()));
717+
std::nullopt, std::nullopt, type_->ToString()));
733718
}
734719

735-
inline ::testing::Matcher<const ::xls::Channel*> ChannelWithType(
720+
inline ::testing::Matcher<::xls::ChannelRef> ChannelWithType(
736721
std::string_view type_string) {
737722
return ::testing::MakeMatcher(new ::xls::op_matchers::ChannelMatcher(
738-
std::nullopt, std::nullopt, std::nullopt, type_string));
723+
std::nullopt, std::nullopt, type_string));
739724
}
740725

741726
// Abstract base class for matchers of nodes which use channels.
742727
class ChannelNodeMatcher : public NodeMatcher {
743728
public:
744729
ChannelNodeMatcher(
745730
Op op, absl::Span<const ::testing::Matcher<const Node*>> operands,
746-
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher)
731+
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher)
747732
: NodeMatcher(op, operands),
748733
channel_matcher_(std::move(channel_matcher)) {}
749734

@@ -752,50 +737,50 @@ class ChannelNodeMatcher : public NodeMatcher {
752737
void DescribeTo(::std::ostream* os) const override;
753738

754739
private:
755-
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher_;
740+
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher_;
756741
};
757742

758743
// Send matcher. Supported forms:
759744
//
760745
// EXPECT_THAT(foo, m::Send());
761-
// EXPECT_THAT(foo, m::Send(m::Channel(42)));
746+
// EXPECT_THAT(foo, m::Send(m::Channel("foo")));
762747
// EXPECT_THAT(foo, m::Send(/*token=*/m::Param(), /*data=*/m::Param(),
763-
// m::Channel(42)));
748+
// m::Channel("bar")));
764749
// EXPECT_THAT(foo, m::Send(/*token=*/m::Param(), /*data=*/m::Param(),
765750
// /*predicate=*/m::Param(),
766-
// m::Channel(42)));
751+
// m::Channel("bazz")));
767752
class SendMatcher : public ChannelNodeMatcher {
768753
public:
769754
explicit SendMatcher(
770-
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher)
755+
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher)
771756
: ChannelNodeMatcher(Op::kSend, {}, std::move(channel_matcher)) {}
772757
explicit SendMatcher(
773758
::testing::Matcher<const Node*> token,
774759
::testing::Matcher<const Node*> data,
775-
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher)
760+
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher)
776761
: ChannelNodeMatcher(Op::kSend, {std::move(token), std::move(data)},
777762
std::move(channel_matcher)) {}
778763
explicit SendMatcher(
779764
::testing::Matcher<const Node*> token,
780765
::testing::Matcher<const Node*> data,
781766
::testing::Matcher<const Node*> predicate,
782-
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher)
767+
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher)
783768
: ChannelNodeMatcher(
784769
Op::kSend,
785770
{std::move(token), std::move(data), std::move(predicate)},
786771
std::move(channel_matcher)) {}
787772
};
788773

789774
inline ::testing::Matcher<const ::xls::Node*> Send(
790-
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher =
775+
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher =
791776
std::nullopt) {
792777
return ::xls::op_matchers::SendMatcher(std::move(channel_matcher));
793778
}
794779

795780
inline ::testing::Matcher<const ::xls::Node*> Send(
796781
::testing::Matcher<const ::xls::Node*> token,
797782
::testing::Matcher<const Node*> data,
798-
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher =
783+
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher =
799784
std::nullopt) {
800785
return ::xls::op_matchers::SendMatcher(std::move(token), std::move(data),
801786
std::move(channel_matcher));
@@ -805,7 +790,7 @@ inline ::testing::Matcher<const ::xls::Node*> Send(
805790
::testing::Matcher<const ::xls::Node*> token,
806791
::testing::Matcher<const Node*> data,
807792
::testing::Matcher<const Node*> predicate,
808-
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher =
793+
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher =
809794
std::nullopt) {
810795
return ::xls::op_matchers::SendMatcher(std::move(token), std::move(data),
811796
std::move(predicate),
@@ -822,31 +807,31 @@ inline ::testing::Matcher<const ::xls::Node*> Send(
822807
class ReceiveMatcher : public ChannelNodeMatcher {
823808
public:
824809
explicit ReceiveMatcher(
825-
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher)
810+
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher)
826811
: ChannelNodeMatcher(Op::kReceive, {}, std::move(channel_matcher)) {}
827812
explicit ReceiveMatcher(
828813
::testing::Matcher<const Node*> token,
829-
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher)
814+
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher)
830815
: ChannelNodeMatcher(Op::kReceive, {std::move(token)},
831816
std::move(channel_matcher)) {}
832817
explicit ReceiveMatcher(
833818
::testing::Matcher<const Node*> token,
834819
::testing::Matcher<const Node*> predicate,
835-
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher)
820+
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher)
836821
: ChannelNodeMatcher(Op::kReceive,
837822
{std::move(token), std::move(predicate)},
838823
std::move(channel_matcher)) {}
839824
};
840825

841826
inline ::testing::Matcher<const ::xls::Node*> Receive(
842-
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher =
827+
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher =
843828
std::nullopt) {
844829
return ::xls::op_matchers::ReceiveMatcher(std::move(channel_matcher));
845830
}
846831

847832
inline ::testing::Matcher<const ::xls::Node*> Receive(
848833
::testing::Matcher<const Node*> token,
849-
std::optional<::testing::Matcher<const ::xls::Channel*>> channel_matcher =
834+
std::optional<::testing::Matcher<::xls::ChannelRef>> channel_matcher =
850835
std::nullopt) {
851836
return ::xls::op_matchers::ReceiveMatcher(std::move(token),
852837
std::move(channel_matcher));
@@ -855,7 +840,7 @@ inline ::testing::Matcher<const ::xls::Node*> Receive(
855840
inline ::testing::Matcher<const ::xls::Node*> Receive(
856841
::testing::Matcher<const Node*> token,
857842
::testing::Matcher<const Node*> predicate,
858-
::testing::Matcher<const ::xls::Channel*> channel_matcher) {
843+
::testing::Matcher<::xls::ChannelRef> channel_matcher) {
859844
return ::xls::op_matchers::ReceiveMatcher(
860845
std::move(token), std::move(predicate), channel_matcher);
861846
}

0 commit comments

Comments
 (0)