@@ -98,7 +98,7 @@ def validate_and_infer_types(self):
98
98
self .set_output_type (0 , self .get_input_element_type (0 ), self .get_input_partial_shape (0 ))
99
99
100
100
def clone_with_new_inputs (self , new_inputs ):
101
- return CustomOpWithAttribute (new_inputs )
101
+ return CustomOpWithAttribute (new_inputs , self . _attrs )
102
102
103
103
def get_type_info (self ):
104
104
return CustomOpWithAttribute .class_type_info
@@ -138,8 +138,7 @@ def prepared_paths(request, tmp_path):
138
138
({"wrong_np" : np .array ([1.5 , 2.5 ], dtype = "complex128" )}, pytest .raises (TypeError ), "Unsupported NumPy array dtype: complex128" ),
139
139
({"wrong" : {}}, pytest .raises (TypeError ), "Unsupported attribute type: <class 'dict'>" )
140
140
])
141
- @pytest .mark .skipif (sys .platform == "win32" , reason = "CVS-164354 BUG: hanged on windows wheels" )
142
- def test_visit_attributes_custom_op (prepared_paths , attributes , expectation , raise_msg ):
141
+ def test_visit_attributes_custom_op (device , prepared_paths , attributes , expectation , raise_msg ):
143
142
input_shape = [2 , 1 ]
144
143
145
144
param1 = ops .parameter (Shape (input_shape ), dtype = np .float32 , name = "data1" )
@@ -165,7 +164,7 @@ def test_visit_attributes_custom_op(prepared_paths, attributes, expectation, rai
165
164
input_data = np .ones ([2 , 1 ], dtype = np .float32 )
166
165
expected_output = np .maximum (0.0 , input_data )
167
166
168
- compiled_model = compile_model (model_with_op_attr )
167
+ compiled_model = compile_model (model_with_op_attr , device )
169
168
input_tensor = Tensor (input_data )
170
169
results = compiled_model ({"data1" : input_tensor })
171
170
assert np .allclose (results [list (results )[0 ]], expected_output , 1e-4 , 1e-4 )
@@ -200,9 +199,9 @@ def test_custom_add_model():
200
199
assert op_types == ["Parameter" , "Parameter" , "CustomAdd" , "Result" ]
201
200
202
201
203
- def test_custom_op ():
202
+ def test_custom_op (device ):
204
203
model = create_snake_model ()
205
- compiled_model = compile_model (model )
204
+ compiled_model = compile_model (model , device )
206
205
207
206
assert isinstance (compiled_model , CompiledModel )
208
207
request = compiled_model .create_infer_request ()
0 commit comments