|
4 | 4 |
|
5 | 5 | #pragma once
|
6 | 6 |
|
| 7 | +#include "openvino/pass/pattern/op/op.hpp" |
7 | 8 | #include "openvino/pass/pattern/op/pattern.hpp"
|
8 | 9 |
|
9 | 10 | namespace ov::pass::pattern {
|
@@ -84,37 +85,33 @@ void collect_type_info(std::vector<DiscreteTypeInfo>& type_info_vec) {
|
84 | 85 | collect_type_info<NodeTypeArgs...>(type_info_vec);
|
85 | 86 | }
|
86 | 87 |
|
87 |
| -template <class... NodeTypes, typename TPredicate> |
88 |
| -std::shared_ptr<Node> optional(const OutputVector& inputs, const TPredicate& pred) { |
| 88 | +template <class... NodeTypes, |
| 89 | + typename TPredicate, |
| 90 | + typename std::enable_if_t<std::is_constructible_v<op::Predicate, TPredicate>>* = nullptr> |
| 91 | +std::shared_ptr<Node> optional(const PatternOps& inputs, const TPredicate& pred, const Attributes& attrs = {}) { |
89 | 92 | std::vector<DiscreteTypeInfo> optional_type_info_vec;
|
90 | 93 | collect_type_info<NodeTypes...>(optional_type_info_vec);
|
91 |
| - return std::make_shared<op::Optional>(optional_type_info_vec, inputs, op::Predicate(pred)); |
92 |
| -} |
93 |
| - |
94 |
| -template <class... NodeTypes, typename TPredicate> |
95 |
| -std::shared_ptr<Node> optional(const Output<Node>& input, const TPredicate& pred) { |
96 |
| - return optional<NodeTypes...>(OutputVector{input}, op::Predicate(pred)); |
| 94 | + return std::make_shared<op::Optional>( |
| 95 | + optional_type_info_vec, |
| 96 | + ov::OutputVector(inputs), |
| 97 | + attrs.empty() ? op::Predicate(pred) : attrs_match(attrs) && op::Predicate(pred)); |
97 | 98 | }
|
98 | 99 |
|
99 | 100 | template <class... NodeTypes,
|
100 | 101 | typename TPredicate,
|
101 |
| - typename std::enable_if_t<std::is_constructible_v<op::Predicate, TPredicate>>* = nullptr> |
102 |
| -std::shared_ptr<Node> optional(const TPredicate& pred) { |
103 |
| - return optional<NodeTypes...>(OutputVector{}, op::Predicate(pred)); |
104 |
| -} |
105 |
| - |
106 |
| -template <class... NodeTypes> |
107 |
| -std::shared_ptr<Node> optional(const OutputVector& inputs) { |
108 |
| - return optional<NodeTypes...>(inputs, op::Predicate()); |
| 102 | + typename std::enable_if_t<std::is_constructible_v<op::Predicate, TPredicate> && |
| 103 | + !std::is_constructible_v<std::vector<PatternOp>, TPredicate>>* = nullptr> |
| 104 | +std::shared_ptr<Node> optional(const TPredicate& pred, const Attributes& attrs = {}) { |
| 105 | + return optional<NodeTypes...>(OutputVector{}, op::Predicate(pred), attrs); |
109 | 106 | }
|
110 | 107 |
|
111 | 108 | template <class... NodeTypes>
|
112 |
| -std::shared_ptr<Node> optional(const Output<Node>& input) { |
113 |
| - return optional<NodeTypes...>(OutputVector{input}, op::Predicate()); |
| 109 | +std::shared_ptr<Node> optional(const PatternOps& inputs = {}, const Attributes& attrs = {}) { |
| 110 | + return optional<NodeTypes...>(inputs, attrs.empty() ? op::Predicate() : attrs_match(attrs)); |
114 | 111 | }
|
115 | 112 |
|
116 | 113 | template <class... NodeTypes>
|
117 |
| -std::shared_ptr<Node> optional() { |
118 |
| - return optional<NodeTypes...>(OutputVector{}, op::Predicate()); |
| 114 | +std::shared_ptr<Node> optional(std::initializer_list<std::pair<const std::string, ov::Any>>&& attrs) { |
| 115 | + return optional<NodeTypes...>(OutputVector{}, attrs); |
119 | 116 | }
|
120 | 117 | } // namespace ov::pass::pattern
|
0 commit comments