diff --git a/xls/ir/instantiation.h b/xls/ir/instantiation.h index 561d7a69ef..f937fd576b 100644 --- a/xls/ir/instantiation.h +++ b/xls/ir/instantiation.h @@ -47,6 +47,11 @@ absl::StatusOr StringToInstantiationKind( std::string_view str); std::ostream& operator<<(std::ostream& os, InstantiationKind kind); +template +void AbslStringify(Sink& sink, const InstantiationKind kind) { + absl::Format(&sink, "%s", InstantiationKindToString(kind)); +} + struct InstantiationPort { std::string name; Type* type; diff --git a/xls/ir/ir_matcher.cc b/xls/ir/ir_matcher.cc index 55cb6cec6a..0ab662da89 100644 --- a/xls/ir/ir_matcher.cc +++ b/xls/ir/ir_matcher.cc @@ -20,6 +20,7 @@ #include #include #include +#include #include #include "gmock/gmock.h" @@ -674,14 +675,14 @@ void MinDelayMatcher::DescribeTo(::std::ostream* os) const { bool InstantiationMatcher::MatchAndExplain( const ::xls::Instantiation* instantiation, ::testing::MatchResultListener* listener) const { - *listener << instantiation->name(); if (name_.has_value() && !name_->MatchAndExplain(instantiation->name(), listener)) { return false; } if (kind_.has_value() && *kind_ != instantiation->kind()) { - *listener << " has incorrect kind, expected: " << *kind_; + *listener << absl::StreamFormat("%s has incorrect kind, expected: %v", + instantiation->name(), *kind_); return false; } return true; @@ -715,5 +716,79 @@ void InstantiationMatcher::DescribeNegationTo(std::ostream* os) const { name_str, kind_str); } +bool InstantiationOutputMatcher::MatchAndExplain( + const Node* node, ::testing::MatchResultListener* listener) const { + if (!NodeMatcher::MatchAndExplain(node, listener)) { + return false; + } + if (port_name_.has_value() && + !port_name_->MatchAndExplain( + node->As<::xls::InstantiationOutput>()->port_name(), listener)) { + return false; + } + if (instantiation_.has_value() && + !instantiation_->MatchAndExplain( + node->As<::xls::InstantiationOutput>()->instantiation(), listener)) { + return false; + } + return true; +} + +void InstantiationOutputMatcher::DescribeTo(::std::ostream* os) const { + std::vector additional_fields; + if (port_name_.has_value()) { + std::stringstream ss; + ss << "name=\""; + port_name_->DescribeTo(&ss); + ss << '"'; + additional_fields.push_back(std::move(ss).str()); + } + if (instantiation_.has_value()) { + std::stringstream ss; + ss << "instantiation=\""; + instantiation_->DescribeTo(&ss); + ss << '"'; + additional_fields.push_back(std::move(ss).str()); + } + DescribeToHelper(os, additional_fields); +} + +bool InstantiationInputMatcher::MatchAndExplain( + const Node* node, ::testing::MatchResultListener* listener) const { + if (!NodeMatcher::MatchAndExplain(node, listener)) { + return false; + } + if (name_.has_value() && + !name_->MatchAndExplain(node->As()->port_name(), + listener)) { + return false; + } + if (instantiation_.has_value() && + !instantiation_->MatchAndExplain( + node->As<::xls::InstantiationInput>()->instantiation(), listener)) { + return false; + } + return true; +} + +void InstantiationInputMatcher::DescribeTo(::std::ostream* os) const { + std::vector additional_fields; + if (name_.has_value()) { + std::stringstream ss; + ss << "name=\""; + name_->DescribeTo(&ss); + ss << '"'; + additional_fields.push_back(std::move(ss).str()); + } + if (instantiation_.has_value()) { + std::stringstream ss; + ss << "instantiation=\""; + instantiation_->DescribeTo(&ss); + ss << '"'; + additional_fields.push_back(std::move(ss).str()); + } + DescribeToHelper(os, additional_fields); +} + } // namespace op_matchers } // namespace xls diff --git a/xls/ir/ir_matcher.h b/xls/ir/ir_matcher.h index 183ed11214..984c278665 100644 --- a/xls/ir/ir_matcher.h +++ b/xls/ir/ir_matcher.h @@ -1385,25 +1385,28 @@ class InstantiationMatcher { std::optional kind_; }; -inline InstantiationMatcher Instantiation( +inline ::testing::Matcher Instantiation( std::optional name = std::nullopt) { - return ::xls::op_matchers::InstantiationMatcher(std::move(name), - std::nullopt); + if (name.has_value()) { + return ::xls::op_matchers::InstantiationMatcher( + internal::NameMatcherInternal(std::move(name).value()), std::nullopt); + } + return ::xls::op_matchers::InstantiationMatcher(std::nullopt, std::nullopt); } -inline InstantiationMatcher Instantiation( +inline ::testing::Matcher Instantiation( std::optional<::testing::Matcher> name = std::nullopt) { return ::xls::op_matchers::InstantiationMatcher(std::move(name), std::nullopt); } -inline InstantiationMatcher Instantiation(std::string name, - InstantiationKind kind) { +inline ::testing::Matcher Instantiation( + std::string name, InstantiationKind kind) { return ::xls::op_matchers::InstantiationMatcher( internal::NameMatcherInternal(std::move(name)), kind); } -inline InstantiationMatcher Instantiation( +inline ::testing::Matcher Instantiation( ::testing::Matcher name, InstantiationKind kind) { return ::xls::op_matchers::InstantiationMatcher(std::move(name), kind); } @@ -1412,6 +1415,105 @@ inline InstantiationMatcher Instantiation(InstantiationKind kind) { return ::xls::op_matchers::InstantiationMatcher(std::nullopt, kind); } +class InstantiationOutputMatcher : public NodeMatcher { + public: + using is_gtest_matcher = void; + + explicit InstantiationOutputMatcher( + std::optional<::testing::Matcher> port_name, + std::optional<::testing::Matcher> + instantiation) + : NodeMatcher(Op::kInstantiationOutput, /*operands=*/{}), + port_name_(std::move(port_name)), + instantiation_(std::move(instantiation)) {} + + bool MatchAndExplain(const Node* node, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(::std::ostream* os) const override; + void DescribeNegationTo(std::ostream* os) const override { + *os << "did not match: "; + DescribeTo(os); + } + + private: + std::optional<::testing::Matcher> port_name_; + std::optional<::testing::Matcher> instantiation_; +}; + +class InstantiationInputMatcher : public NodeMatcher { + public: + using is_gtest_matcher = void; + + explicit InstantiationInputMatcher( + ::testing::Matcher data, + std::optional<::testing::Matcher> name, + std::optional<::testing::Matcher> + instantiation) + : NodeMatcher(Op::kInstantiationInput, + /*operands=*/{std::move(data)}), + name_(std::move(name)), + instantiation_(std::move(instantiation)) {} + + bool MatchAndExplain(const Node* node, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(::std::ostream* os) const override; + void DescribeNegationTo(std::ostream* os) const override { + *os << "did not match: "; + DescribeTo(os); + } + + private: + std::optional<::testing::Matcher> name_; + std::optional<::testing::Matcher> instantiation_; +}; + +inline ::testing::Matcher InstantiationOutput() { + return ::xls::op_matchers::InstantiationOutputMatcher(std::nullopt, + std::nullopt); +} + +inline ::testing::Matcher InstantiationOutput( + const char* port_name, + std::optional<::testing::Matcher> + instantiation = std::nullopt) { + return ::xls::op_matchers::InstantiationOutputMatcher( + internal::NameMatcherInternal(std::string{port_name}), + std::move(instantiation)); +} + +inline ::testing::Matcher InstantiationOutput( + ::testing::Matcher port_name, + std::optional<::testing::Matcher> + instantiation = std::nullopt) { + return ::xls::op_matchers::InstantiationOutputMatcher( + std::move(port_name), std::move(instantiation)); +} + +inline ::xls::op_matchers::InstantiationInputMatcher InstantiationInput( + ::testing::Matcher data = + ::testing::A()) { + return ::xls::op_matchers::InstantiationInputMatcher( + std::move(data), std::nullopt, std::nullopt); +} + +inline ::xls::op_matchers::InstantiationInputMatcher InstantiationInput( + ::testing::Matcher data, + ::testing::Matcher name, + std::optional<::testing::Matcher> + instantiation = std::nullopt) { + return ::xls::op_matchers::InstantiationInputMatcher( + std::move(data), std::move(name), std::move(instantiation)); +} + +inline ::xls::op_matchers::InstantiationInputMatcher InstantiationInput( + ::testing::Matcher data, const char* name, + std::optional<::testing::Matcher> + instantiation = std::nullopt) { + return ::xls::op_matchers::InstantiationInputMatcher( + std::move(data), internal::NameMatcherInternal(std::string{name}), + std::move(instantiation)); +} + } // namespace op_matchers } // namespace xls diff --git a/xls/ir/ir_matcher_test.cc b/xls/ir/ir_matcher_test.cc index c329aa5ffa..02c250befa 100644 --- a/xls/ir/ir_matcher_test.cc +++ b/xls/ir/ir_matcher_test.cc @@ -628,7 +628,7 @@ TEST(IrMatchersTest, InstantiationMatcher) { BValue x = bb.InputPort("x", u32); BValue y = bb.InputPort("y", u32); - bb.InstantiationInput(add0, "a", x); + BValue a = bb.InstantiationInput(add0, "a", x); bb.InstantiationInput(add0, "b", y); BValue x_plus_y = bb.InstantiationOutput(add0, "result"); bb.OutputPort("x_plus_y", x_plus_y); @@ -653,6 +653,31 @@ TEST(IrMatchersTest, InstantiationMatcher) { EXPECT_THAT(Explain(block->GetInstantiations().at(0), m::Instantiation(InstantiationKind::kExtern)), HasSubstr("add0 has incorrect kind, expected: extern")); + + EXPECT_THAT( + block->nodes(), + AllOf( + Contains(m::InstantiationOutput()), + Contains(m::InstantiationOutput("result")), + Contains(m::InstantiationOutput(HasSubstr("res"))), + Contains(m::InstantiationOutput("result", m::Instantiation("add0"))), + Contains(m::InstantiationInput(m::InputPort("x"))), + Contains(m::InstantiationInput(m::InputPort(HasSubstr("x")))), + Contains(m::InstantiationInput(m::InputPort("x"), "a")), + Contains(m::InstantiationInput(m::InputPort("x"), "a", + m::Instantiation("add0"))))); + EXPECT_THAT(a.node(), ::testing::Not(m::InstantiationInput( + m::InputPort("x"), HasSubstr("b"), + m::Instantiation("add0")))); + + EXPECT_THAT(Explain(a.node(), m::InstantiationInput(m::InputPort("y"))), + HasSubstr("x has incorrect name, expected: y.")); + EXPECT_THAT(Explain(a.node(), m::InstantiationInput(m::InputPort("x"), "b")), + HasSubstr("a has incorrect name, expected: b.")); + EXPECT_THAT( + Explain(a.node(), m::InstantiationInput(m::InputPort("x"), "a", + m::Instantiation("add1"))), + HasSubstr("add0 has incorrect name, expected: add1.")); } } // namespace