Skip to content

Commit 7306302

Browse files
Evgenya Nugmanovapraaszitikhono
authored
Attributes matching in a pattern, Synth sugar (#29132)
### Details: - *Updates wrap_type, any_input and optional patterns with an option to match attributes of the node* - *Introduces attrs_match `Predicate`* ### Tickets: - *CVS-162245* --------- Signed-off-by: Evgeniia Nugmanova <[email protected]> Co-authored-by: Pawel Raasz <[email protected]> Co-authored-by: Ivan Tikhonov <[email protected]>
1 parent 864b4ff commit 7306302

File tree

20 files changed

+657
-95
lines changed

20 files changed

+657
-95
lines changed

src/bindings/python/src/openvino/passes/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
type_matches,
1515
type_matches_any,
1616
shape_matches,
17+
attrs_match,
1718
)
1819
from openvino._pyopenvino.passes import Serialize, ConstantFolding, VisualizeTree, MakeStateful, LowLatency2, ConvertFP32ToFP16, Version
1920
from openvino.passes.manager import Manager

src/bindings/python/src/pyopenvino/graph/passes/pattern_ops.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "openvino/pass/pattern/op/pattern.hpp"
1919
#include "openvino/pass/pattern/op/wrap_type.hpp"
2020
#include "pyopenvino/core/common.hpp"
21+
#include "pyopenvino/utils/utils.hpp"
2122

2223
static ov::NodeTypeInfo get_type(const std::string& type_name) {
2324
// Supported types: opsetX.OpName or opsetX::OpName
@@ -1014,6 +1015,9 @@ inline void reg_predicates(py::module m) {
10141015
m.def("type_matches", &ov::pass::pattern::type_matches);
10151016
m.def("type_matches_any", &ov::pass::pattern::type_matches_any);
10161017
m.def("shape_matches", &ov::pass::pattern::shape_matches);
1018+
m.def("attrs_match", [](py::object& attrs) {
1019+
return ov::pass::pattern::attrs_match(Common::utils::py_object_to_unordered_any_map(attrs));
1020+
});
10171021
}
10181022

10191023
void reg_passes_pattern_ops(py::module m) {

src/bindings/python/src/pyopenvino/utils/utils.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,21 @@ ov::AnyMap py_object_to_any_map(const py::object& py_obj) {
411411
return return_value;
412412
}
413413

414+
std::unordered_map<std::string, ov::Any> py_object_to_unordered_any_map(const py::object& py_obj) {
415+
OPENVINO_ASSERT(py_object_is_any_map(py_obj), "Unsupported attribute type.");
416+
std::unordered_map<std::string, ov::Any> return_value = {};
417+
for (auto& item : py::cast<py::dict>(py_obj)) {
418+
std::string key = py::cast<std::string>(item.first);
419+
py::object value = py::cast<py::object>(item.second);
420+
if (py_object_is_any_map(value)) {
421+
return_value[key] = Common::utils::py_object_to_any_map(value);
422+
} else {
423+
return_value[key] = Common::utils::py_object_to_any(value);
424+
}
425+
}
426+
return return_value;
427+
}
428+
414429
template <typename... Args, std::size_t... I>
415430
std::tuple<Args...> tuple_from_py_tuple_impl(const py::tuple& py_tuple, std::index_sequence<I...>) {
416431
return std::make_tuple(py_tuple[I].cast<Args>()...);

src/bindings/python/src/pyopenvino/utils/utils.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,8 @@ class MemoryBuffer : public std::streambuf {
140140

141141
ov::AnyMap py_object_to_any_map(const py::object& py_obj);
142142

143+
std::unordered_map<std::string, ov::Any> py_object_to_unordered_any_map(const py::object& py_obj);
144+
143145
ov::Any py_object_to_any(const py::object& py_obj);
144146

145147
ov::pass::Serialize::Version convert_to_version(const std::string& version);

src/bindings/python/tests/test_transformations/test_pattern_ops.py

+14
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
type_matches,
1717
type_matches_any,
1818
shape_matches,
19+
attrs_match,
1920
)
2021
from openvino.utils.types import get_element_type
2122

@@ -278,6 +279,19 @@ def symbol_matching_test(shape: PartialShape, pattern: str):
278279
assert symbols["Six"] == 6, symbols
279280

280281

282+
def test_attrs_match():
283+
param = ops.parameter([-1, -1])
284+
285+
def test_shape_of_attribute(et: str):
286+
node = ops.shape_of(param, output_type=et)
287+
attr = {"output_type": et}
288+
matcher = Matcher(AnyInput(attrs_match(attr)), "Find shape_of with attribute")
289+
assert matcher.match(node), f"Match failed for {node} with attribute"
290+
291+
test_shape_of_attribute("i64")
292+
test_shape_of_attribute("i32")
293+
294+
281295
def test_optional_full_match():
282296
model_input = ops.parameter(PartialShape.dynamic())
283297
model_abs = ops.abs(model_input)

src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -346,16 +346,12 @@ ov::pass::DeReshapeFullyConnected::DeReshapeFullyConnected() {
346346
using namespace ov::op;
347347
using namespace ov::pass::pattern;
348348

349-
auto transpose_a_false = [](const std::shared_ptr<Node>& node) -> bool {
350-
auto mm = as_type_ptr<v0::MatMul>(node);
351-
return mm && !mm->get_transpose_a();
352-
};
353-
354349
auto input = wrap_type<v1::Reshape>({any_input(shape_matches("BATCHES_1...,Y")), any_input()},
355350
shape_matches("BATCHES_2...,Y"));
356351
auto converted = pattern::optional<v0::Convert>(input, consumers_count(1));
357352
auto mm_label = wrap_type<v0::MatMul>({converted, any_input(rank_equals(2))},
358-
consumers_count(1) && transpose_a_false && shape_matches("BATCHES_2...,Z"));
353+
consumers_count(1) && shape_matches("BATCHES_2...,Z"),
354+
{{"transpose_a", false}});
359355
auto output = wrap_type<v1::Reshape>({mm_label, any_input()}, shape_matches("BATCHES_1...,Z"));
360356

361357
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {

src/core/include/openvino/core/attribute_adapter.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ class IndirectVectorValueAccessor : public ValueAccessor<VAT> {
171171
OPENVINO_THROW("Bad cast from: ", x.type_info().name(), " to: ", typeid(AT).name());
172172
}
173173
}
174+
174175
operator AT&() {
175176
return m_ref;
176177
}

src/core/include/openvino/pass/pattern/op/label.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,11 @@ class OPENVINO_API Label : public Pattern {
7272
};
7373
} // namespace op
7474

75-
OPENVINO_API std::shared_ptr<Node> any_input();
75+
OPENVINO_API std::shared_ptr<Node> any_input(const Attributes& attrs = {});
7676

77-
template <typename TPredicate>
77+
template <typename TPredicate,
78+
typename std::enable_if_t<std::is_constructible_v<op::Predicate, TPredicate> &&
79+
!std::is_constructible_v<Attributes, TPredicate>>* = nullptr>
7880
std::shared_ptr<Node> any_input(const TPredicate& pred) {
7981
return std::make_shared<pattern::op::Label>(element::dynamic, PartialShape::dynamic(), op::Predicate(pred));
8082
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "openvino/op/constant.hpp"
8+
#include "openvino/pass/pattern/op/or.hpp"
9+
#include "openvino/pass/pattern/op/predicate.hpp"
10+
11+
namespace ov::pass::pattern {
12+
// A glue/syntax-sugar type which allows more types to be used as input to pattern operations
13+
struct OPENVINO_API PatternOp {
14+
private:
15+
std::shared_ptr<ov::Node> op;
16+
int64_t output_idx = -1;
17+
18+
public:
19+
operator ov::Output<ov::Node>() const;
20+
ov::Output<ov::Node> get_output() const;
21+
22+
PatternOp(const Output<Node>& out);
23+
24+
template <typename T, typename std::enable_if_t<std::is_base_of_v<ov::Node, T>>* = nullptr>
25+
PatternOp(const std::shared_ptr<T>& op) : op(std::dynamic_pointer_cast<ov::Node>(op)) {}
26+
27+
PatternOp(const std::shared_ptr<Node>& op);
28+
PatternOp(ov::Rank rank);
29+
30+
// Constant matching
31+
PatternOp(const char* value_notation);
32+
PatternOp(std::string value_notation);
33+
PatternOp(int v);
34+
PatternOp(float v);
35+
PatternOp(double v);
36+
PatternOp(long long v);
37+
38+
PatternOp(std::initializer_list<const char*>&& v);
39+
PatternOp(std::initializer_list<const std::string>&& v);
40+
PatternOp(std::initializer_list<const int>&& v);
41+
PatternOp(std::initializer_list<const float>&& v);
42+
PatternOp(std::initializer_list<const double>&& v);
43+
PatternOp(std::initializer_list<const long long>&& v);
44+
};
45+
46+
// Syntax-sugar type for pattern operators to consume all the different ways to pass containter of inputs with use of
47+
// PatternOp
48+
struct OPENVINO_API PatternOps {
49+
private:
50+
std::vector<PatternOp> data;
51+
52+
public:
53+
PatternOps();
54+
55+
// single element
56+
template <typename T, typename std::enable_if_t<std::is_constructible_v<PatternOp, T>>* = nullptr>
57+
PatternOps(const T& in) : data{PatternOp(in)} {};
58+
PatternOps(const std::shared_ptr<Node>&);
59+
PatternOps(const Output<Node>&);
60+
61+
// multi-element
62+
PatternOps(const OutputVector&);
63+
PatternOps(std::initializer_list<pattern::PatternOp>&&);
64+
65+
explicit operator ov::OutputVector() const;
66+
};
67+
68+
} // namespace ov::pass::pattern

src/core/include/openvino/pass/pattern/op/optional.hpp

+17-20
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#pragma once
66

7+
#include "openvino/pass/pattern/op/op.hpp"
78
#include "openvino/pass/pattern/op/pattern.hpp"
89

910
namespace ov::pass::pattern {
@@ -84,37 +85,33 @@ void collect_type_info(std::vector<DiscreteTypeInfo>& type_info_vec) {
8485
collect_type_info<NodeTypeArgs...>(type_info_vec);
8586
}
8687

87-
template <class... NodeTypes, typename TPredicate>
88-
std::shared_ptr<Node> optional(const OutputVector& inputs, const TPredicate& pred) {
88+
template <class... NodeTypes,
89+
typename TPredicate,
90+
typename std::enable_if_t<std::is_constructible_v<op::Predicate, TPredicate>>* = nullptr>
91+
std::shared_ptr<Node> optional(const PatternOps& inputs, const TPredicate& pred, const Attributes& attrs = {}) {
8992
std::vector<DiscreteTypeInfo> optional_type_info_vec;
9093
collect_type_info<NodeTypes...>(optional_type_info_vec);
91-
return std::make_shared<op::Optional>(optional_type_info_vec, inputs, op::Predicate(pred));
92-
}
93-
94-
template <class... NodeTypes, typename TPredicate>
95-
std::shared_ptr<Node> optional(const Output<Node>& input, const TPredicate& pred) {
96-
return optional<NodeTypes...>(OutputVector{input}, op::Predicate(pred));
94+
return std::make_shared<op::Optional>(
95+
optional_type_info_vec,
96+
ov::OutputVector(inputs),
97+
attrs.empty() ? op::Predicate(pred) : attrs_match(attrs) && op::Predicate(pred));
9798
}
9899

99100
template <class... NodeTypes,
100101
typename TPredicate,
101-
typename std::enable_if_t<std::is_constructible_v<op::Predicate, TPredicate>>* = nullptr>
102-
std::shared_ptr<Node> optional(const TPredicate& pred) {
103-
return optional<NodeTypes...>(OutputVector{}, op::Predicate(pred));
104-
}
105-
106-
template <class... NodeTypes>
107-
std::shared_ptr<Node> optional(const OutputVector& inputs) {
108-
return optional<NodeTypes...>(inputs, op::Predicate());
102+
typename std::enable_if_t<std::is_constructible_v<op::Predicate, TPredicate> &&
103+
!std::is_constructible_v<std::vector<PatternOp>, TPredicate>>* = nullptr>
104+
std::shared_ptr<Node> optional(const TPredicate& pred, const Attributes& attrs = {}) {
105+
return optional<NodeTypes...>(OutputVector{}, op::Predicate(pred), attrs);
109106
}
110107

111108
template <class... NodeTypes>
112-
std::shared_ptr<Node> optional(const Output<Node>& input) {
113-
return optional<NodeTypes...>(OutputVector{input}, op::Predicate());
109+
std::shared_ptr<Node> optional(const PatternOps& inputs = {}, const Attributes& attrs = {}) {
110+
return optional<NodeTypes...>(inputs, attrs.empty() ? op::Predicate() : attrs_match(attrs));
114111
}
115112

116113
template <class... NodeTypes>
117-
std::shared_ptr<Node> optional() {
118-
return optional<NodeTypes...>(OutputVector{}, op::Predicate());
114+
std::shared_ptr<Node> optional(std::initializer_list<std::pair<const std::string, ov::Any>>&& attrs) {
115+
return optional<NodeTypes...>(OutputVector{}, attrs);
119116
}
120117
} // namespace ov::pass::pattern

src/core/include/openvino/pass/pattern/op/or.hpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
#include "openvino/core/node.hpp"
88
#include "openvino/pass/pattern/op/pattern.hpp"
99

10-
namespace ov::pass::pattern::op {
10+
namespace ov::pass {
11+
namespace pattern::op {
1112
/// A submatch on the graph value is performed on each input to the Or; the match
1213
/// succeeds on the first match. Otherwise the match fails.
1314
class OPENVINO_API Or : public Pattern {
@@ -22,4 +23,7 @@ class OPENVINO_API Or : public Pattern {
2223
const Output<Node>& pattern_value,
2324
const Output<Node>& graph_value) override;
2425
};
25-
} // namespace ov::pass::pattern::op
26+
} // namespace pattern::op
27+
OPENVINO_API std::shared_ptr<Node> operator|(const Output<Node>& lhs, const Output<Node>& rhs);
28+
OPENVINO_API std::shared_ptr<Node> operator|(const std::shared_ptr<Node>& lhs, const std::shared_ptr<Node>& rhs);
29+
} // namespace ov::pass

src/core/include/openvino/pass/pattern/op/pattern.hpp

+4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ using PatternValueMap = std::map<std::shared_ptr<Node>, Output<Node>>;
2222
using PatternValueMaps = std::vector<PatternValueMap>;
2323

2424
using PatternMap = std::map<std::shared_ptr<Node>, std::shared_ptr<Node>>;
25+
using Attributes = std::unordered_map<std::string, ov::Any>;
2526

2627
PatternMap as_pattern_map(const PatternValueMap& pattern_value_map);
2728
PatternValueMap as_pattern_value_map(const PatternMap& pattern_map);
@@ -61,7 +62,10 @@ OPENVINO_API op::Predicate type_matches_any(const std::vector<element::Type>& ty
6162

6263
OPENVINO_API op::Predicate all_of(const std::vector<std::function<bool(Output<Node>)>>& predicates);
6364

65+
OPENVINO_API op::Predicate attrs_match(const Attributes& expected_attrs);
66+
6467
OPENVINO_API op::Predicate shape_matches(const std::string& shape_notation);
68+
OPENVINO_API op::Predicate value_matches(const std::string& value_notation);
6569

6670
namespace op {
6771
OPENVINO_DEPRECATED("This method is deprecated. Use constructor of ov::pass::pattern::Predicate instead")

src/core/include/openvino/pass/pattern/op/predicate.hpp

+6
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,12 @@ class OPENVINO_API PatternSymbolValue {
3838
const std::vector<PatternSymbolValue>& g() const;
3939

4040
bool operator==(const PatternSymbolValue& other) const;
41+
bool operator!=(const PatternSymbolValue& other) const;
42+
43+
template <typename T, typename std::enable_if_t<std::is_constructible_v<PatternSymbolValue, T>>* = nullptr>
44+
static std::vector<PatternSymbolValue> make_value_vector(const std::vector<T>& v) {
45+
return {v.begin(), v.end()};
46+
}
4147

4248
private:
4349
bool is_valid() const;

src/core/include/openvino/pass/pattern/op/wrap_type.hpp

+22-8
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "openvino/core/node.hpp"
88
#include "openvino/op/constant.hpp"
9+
#include "openvino/pass/pattern/op/op.hpp"
910
#include "openvino/pass/pattern/op/pattern.hpp"
1011

1112
namespace ov::pass::pattern {
@@ -63,20 +64,33 @@ void collect_wrap_info(std::vector<DiscreteTypeInfo>& info) {
6364
collect_wrap_info<Targs...>(info);
6465
}
6566

66-
template <class... Args, typename TPredicate>
67-
std::shared_ptr<Node> wrap_type(const OutputVector& inputs, const TPredicate& pred) {
67+
template <class... Args,
68+
typename TPredicate,
69+
typename std::enable_if_t<std::is_constructible_v<op::Predicate, TPredicate>>* = nullptr>
70+
std::shared_ptr<Node> wrap_type(const PatternOps& inputs, const TPredicate& pred, const Attributes& attrs = {}) {
6871
std::vector<DiscreteTypeInfo> info;
6972
collect_wrap_info<Args...>(info);
70-
return std::make_shared<op::WrapType>(info, op::Predicate(pred), inputs);
73+
return std::make_shared<op::WrapType>(
74+
info,
75+
(attrs.empty() ? op::Predicate(pred) : attrs_match(attrs) && op::Predicate(pred)),
76+
ov::OutputVector(inputs));
77+
}
78+
79+
template <class... Args,
80+
typename TPredicate,
81+
typename std::enable_if_t<std::is_constructible_v<op::Predicate, TPredicate> &&
82+
!std::is_constructible_v<PatternOps, TPredicate>>* = nullptr>
83+
std::shared_ptr<Node> wrap_type(const TPredicate& pred, const Attributes& attrs = {}) {
84+
return wrap_type<Args...>(PatternOps{}, op::Predicate(pred), attrs);
7185
}
7286

7387
template <class... Args>
74-
std::shared_ptr<Node> wrap_type(const OutputVector& inputs = {}) {
75-
return wrap_type<Args...>(inputs, op::Predicate());
88+
std::shared_ptr<Node> wrap_type(const PatternOps& inputs = {}, const Attributes& attrs = {}) {
89+
return wrap_type<Args...>(inputs, (attrs.empty() ? op::Predicate() : attrs_match(attrs)));
7690
}
7791

78-
template <class... Args, typename TPredicate>
79-
std::shared_ptr<Node> wrap_type(const TPredicate& pred) {
80-
return wrap_type<Args...>({}, op::Predicate(pred));
92+
template <class... Args>
93+
std::shared_ptr<Node> wrap_type(std::initializer_list<std::pair<const std::string, ov::Any>>&& attrs) {
94+
return wrap_type<Args...>(PatternOps{}, Attributes(attrs));
8195
}
8296
} // namespace ov::pass::pattern

src/core/src/pattern/op/label.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,6 @@ bool ov::pass::pattern::op::Label::match_value(ov::pass::pattern::Matcher* match
5454
return false;
5555
}
5656

57-
std::shared_ptr<ov::Node> ov::pass::pattern::any_input() {
58-
return std::make_shared<pattern::op::Label>();
57+
std::shared_ptr<ov::Node> ov::pass::pattern::any_input(const Attributes& attrs) {
58+
return attrs.empty() ? std::make_shared<pattern::op::Label>() : any_input(attrs_match(attrs));
5959
}

0 commit comments

Comments
 (0)