Skip to content

Commit

Permalink
Add matchers for instantiation inputs and outputs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 597893506
  • Loading branch information
grebe authored and copybara-github committed Jan 12, 2024
1 parent ce6ab10 commit bfe2596
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 10 deletions.
5 changes: 5 additions & 0 deletions xls/ir/instantiation.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ absl::StatusOr<InstantiationKind> StringToInstantiationKind(
std::string_view str);
std::ostream& operator<<(std::ostream& os, InstantiationKind kind);

template <typename Sink>
void AbslStringify(Sink& sink, const InstantiationKind kind) {
absl::Format(&sink, "%s", InstantiationKindToString(kind));
}

struct InstantiationPort {
std::string name;
Type* type;
Expand Down
79 changes: 77 additions & 2 deletions xls/ir/ir_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <sstream>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "gmock/gmock.h"
Expand Down Expand Up @@ -674,14 +675,14 @@ void MinDelayMatcher::DescribeTo(::std::ostream* os) const {
bool InstantiationMatcher::MatchAndExplain(
const ::xls::Instantiation* instantiation,
::testing::MatchResultListener* listener) const {
*listener << instantiation->name();
if (name_.has_value() &&
!name_->MatchAndExplain(instantiation->name(), listener)) {
return false;
}

if (kind_.has_value() && *kind_ != instantiation->kind()) {
*listener << " has incorrect kind, expected: " << *kind_;
*listener << absl::StreamFormat("%s has incorrect kind, expected: %v",
instantiation->name(), *kind_);
return false;
}
return true;
Expand Down Expand Up @@ -715,5 +716,79 @@ void InstantiationMatcher::DescribeNegationTo(std::ostream* os) const {
name_str, kind_str);
}

bool InstantiationOutputMatcher::MatchAndExplain(
const Node* node, ::testing::MatchResultListener* listener) const {
if (!NodeMatcher::MatchAndExplain(node, listener)) {
return false;
}
if (port_name_.has_value() &&
!port_name_->MatchAndExplain(
node->As<::xls::InstantiationOutput>()->port_name(), listener)) {
return false;
}
if (instantiation_.has_value() &&
!instantiation_->MatchAndExplain(
node->As<::xls::InstantiationOutput>()->instantiation(), listener)) {
return false;
}
return true;
}

void InstantiationOutputMatcher::DescribeTo(::std::ostream* os) const {
std::vector<std::string> additional_fields;
if (port_name_.has_value()) {
std::stringstream ss;
ss << "name=\"";
port_name_->DescribeTo(&ss);
ss << '"';
additional_fields.push_back(std::move(ss).str());
}
if (instantiation_.has_value()) {
std::stringstream ss;
ss << "instantiation=\"";
instantiation_->DescribeTo(&ss);
ss << '"';
additional_fields.push_back(std::move(ss).str());
}
DescribeToHelper(os, additional_fields);
}

bool InstantiationInputMatcher::MatchAndExplain(
const Node* node, ::testing::MatchResultListener* listener) const {
if (!NodeMatcher::MatchAndExplain(node, listener)) {
return false;
}
if (name_.has_value() &&
!name_->MatchAndExplain(node->As<xls::InstantiationInput>()->port_name(),
listener)) {
return false;
}
if (instantiation_.has_value() &&
!instantiation_->MatchAndExplain(
node->As<::xls::InstantiationInput>()->instantiation(), listener)) {
return false;
}
return true;
}

void InstantiationInputMatcher::DescribeTo(::std::ostream* os) const {
std::vector<std::string> additional_fields;
if (name_.has_value()) {
std::stringstream ss;
ss << "name=\"";
name_->DescribeTo(&ss);
ss << '"';
additional_fields.push_back(std::move(ss).str());
}
if (instantiation_.has_value()) {
std::stringstream ss;
ss << "instantiation=\"";
instantiation_->DescribeTo(&ss);
ss << '"';
additional_fields.push_back(std::move(ss).str());
}
DescribeToHelper(os, additional_fields);
}

} // namespace op_matchers
} // namespace xls
116 changes: 109 additions & 7 deletions xls/ir/ir_matcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -1385,25 +1385,28 @@ class InstantiationMatcher {
std::optional<InstantiationKind> kind_;
};

inline InstantiationMatcher Instantiation(
inline ::testing::Matcher<const ::xls::Instantiation*> Instantiation(
std::optional<std::string> name = std::nullopt) {
return ::xls::op_matchers::InstantiationMatcher(std::move(name),
std::nullopt);
if (name.has_value()) {
return ::xls::op_matchers::InstantiationMatcher(
internal::NameMatcherInternal(std::move(name).value()), std::nullopt);
}
return ::xls::op_matchers::InstantiationMatcher(std::nullopt, std::nullopt);
}

inline InstantiationMatcher Instantiation(
inline ::testing::Matcher<const ::xls::Instantiation*> Instantiation(
std::optional<::testing::Matcher<const std::string>> name = std::nullopt) {
return ::xls::op_matchers::InstantiationMatcher(std::move(name),
std::nullopt);
}

inline InstantiationMatcher Instantiation(std::string name,
InstantiationKind kind) {
inline ::testing::Matcher<const ::xls::Instantiation*> Instantiation(
std::string name, InstantiationKind kind) {
return ::xls::op_matchers::InstantiationMatcher(
internal::NameMatcherInternal(std::move(name)), kind);
}

inline InstantiationMatcher Instantiation(
inline ::testing::Matcher<const ::xls::Instantiation*> Instantiation(
::testing::Matcher<const std::string> name, InstantiationKind kind) {
return ::xls::op_matchers::InstantiationMatcher(std::move(name), kind);
}
Expand All @@ -1412,6 +1415,105 @@ inline InstantiationMatcher Instantiation(InstantiationKind kind) {
return ::xls::op_matchers::InstantiationMatcher(std::nullopt, kind);
}

class InstantiationOutputMatcher : public NodeMatcher {
public:
using is_gtest_matcher = void;

explicit InstantiationOutputMatcher(
std::optional<::testing::Matcher<std::string>> port_name,
std::optional<::testing::Matcher<const class Instantiation*>>
instantiation)
: NodeMatcher(Op::kInstantiationOutput, /*operands=*/{}),
port_name_(std::move(port_name)),
instantiation_(std::move(instantiation)) {}

bool MatchAndExplain(const Node* node,
::testing::MatchResultListener* listener) const override;
void DescribeTo(::std::ostream* os) const override;
void DescribeNegationTo(std::ostream* os) const override {
*os << "did not match: ";
DescribeTo(os);
}

private:
std::optional<::testing::Matcher<std::string>> port_name_;
std::optional<::testing::Matcher<const class Instantiation*>> instantiation_;
};

class InstantiationInputMatcher : public NodeMatcher {
public:
using is_gtest_matcher = void;

explicit InstantiationInputMatcher(
::testing::Matcher<const ::xls::Node*> data,
std::optional<::testing::Matcher<std::string>> name,
std::optional<::testing::Matcher<const class Instantiation*>>
instantiation)
: NodeMatcher(Op::kInstantiationInput,
/*operands=*/{std::move(data)}),
name_(std::move(name)),
instantiation_(std::move(instantiation)) {}

bool MatchAndExplain(const Node* node,
::testing::MatchResultListener* listener) const override;
void DescribeTo(::std::ostream* os) const override;
void DescribeNegationTo(std::ostream* os) const override {
*os << "did not match: ";
DescribeTo(os);
}

private:
std::optional<::testing::Matcher<std::string>> name_;
std::optional<::testing::Matcher<const class Instantiation*>> instantiation_;
};

inline ::testing::Matcher<const ::xls::Node*> InstantiationOutput() {
return ::xls::op_matchers::InstantiationOutputMatcher(std::nullopt,
std::nullopt);
}

inline ::testing::Matcher<const ::xls::Node*> InstantiationOutput(
const char* port_name,
std::optional<::testing::Matcher<const ::xls::Instantiation*>>
instantiation = std::nullopt) {
return ::xls::op_matchers::InstantiationOutputMatcher(
internal::NameMatcherInternal(std::string{port_name}),
std::move(instantiation));
}

inline ::testing::Matcher<const ::xls::Node*> InstantiationOutput(
::testing::Matcher<std::string> port_name,
std::optional<::testing::Matcher<const ::xls::Instantiation*>>
instantiation = std::nullopt) {
return ::xls::op_matchers::InstantiationOutputMatcher(
std::move(port_name), std::move(instantiation));
}

inline ::xls::op_matchers::InstantiationInputMatcher InstantiationInput(
::testing::Matcher<const ::xls::Node*> data =
::testing::A<const ::xls::Node*>()) {
return ::xls::op_matchers::InstantiationInputMatcher(
std::move(data), std::nullopt, std::nullopt);
}

inline ::xls::op_matchers::InstantiationInputMatcher InstantiationInput(
::testing::Matcher<const ::xls::Node*> data,
::testing::Matcher<std::string> name,
std::optional<::testing::Matcher<const ::xls::Instantiation*>>
instantiation = std::nullopt) {
return ::xls::op_matchers::InstantiationInputMatcher(
std::move(data), std::move(name), std::move(instantiation));
}

inline ::xls::op_matchers::InstantiationInputMatcher InstantiationInput(
::testing::Matcher<const ::xls::Node*> data, const char* name,
std::optional<::testing::Matcher<const ::xls::Instantiation*>>
instantiation = std::nullopt) {
return ::xls::op_matchers::InstantiationInputMatcher(
std::move(data), internal::NameMatcherInternal(std::string{name}),
std::move(instantiation));
}

} // namespace op_matchers
} // namespace xls

Expand Down
27 changes: 26 additions & 1 deletion xls/ir/ir_matcher_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ TEST(IrMatchersTest, InstantiationMatcher) {
BValue x = bb.InputPort("x", u32);
BValue y = bb.InputPort("y", u32);

bb.InstantiationInput(add0, "a", x);
BValue a = bb.InstantiationInput(add0, "a", x);
bb.InstantiationInput(add0, "b", y);
BValue x_plus_y = bb.InstantiationOutput(add0, "result");
bb.OutputPort("x_plus_y", x_plus_y);
Expand All @@ -653,6 +653,31 @@ TEST(IrMatchersTest, InstantiationMatcher) {
EXPECT_THAT(Explain(block->GetInstantiations().at(0),
m::Instantiation(InstantiationKind::kExtern)),
HasSubstr("add0 has incorrect kind, expected: extern"));

EXPECT_THAT(
block->nodes(),
AllOf(
Contains(m::InstantiationOutput()),
Contains(m::InstantiationOutput("result")),
Contains(m::InstantiationOutput(HasSubstr("res"))),
Contains(m::InstantiationOutput("result", m::Instantiation("add0"))),
Contains(m::InstantiationInput(m::InputPort("x"))),
Contains(m::InstantiationInput(m::InputPort(HasSubstr("x")))),
Contains(m::InstantiationInput(m::InputPort("x"), "a")),
Contains(m::InstantiationInput(m::InputPort("x"), "a",
m::Instantiation("add0")))));
EXPECT_THAT(a.node(), ::testing::Not(m::InstantiationInput(
m::InputPort("x"), HasSubstr("b"),
m::Instantiation("add0"))));

EXPECT_THAT(Explain(a.node(), m::InstantiationInput(m::InputPort("y"))),
HasSubstr("x has incorrect name, expected: y."));
EXPECT_THAT(Explain(a.node(), m::InstantiationInput(m::InputPort("x"), "b")),
HasSubstr("a has incorrect name, expected: b."));
EXPECT_THAT(
Explain(a.node(), m::InstantiationInput(m::InputPort("x"), "a",
m::Instantiation("add1"))),
HasSubstr("add0 has incorrect name, expected: add1."));
}

} // namespace
Expand Down

0 comments on commit bfe2596

Please sign in to comment.