
Commit 93a103b

[PT FE] Support None in example (#28398)
### Details:
- *Support `None` in example*

### Tickets:
- *CVS-156684*

Signed-off-by: Maxim Vafin <[email protected]>
1 parent 2c80544 commit 93a103b

4 files changed (+220 -15 lines)

src/bindings/python/src/openvino/frontend/pytorch/ts_decoder.py (+7 -5)

@@ -4,6 +4,11 @@
 # flake8: noqa
 # mypy: ignore-errors
 
+import inspect
+import logging
+import typing
+import torch
+
 from openvino.frontend.pytorch.py_pytorch_frontend import _FrontEndPytorchDecoder as Decoder
 from openvino.frontend.pytorch.py_pytorch_frontend import _Type as DecoderType
 from openvino import op, PartialShape, Type as OVType, OVAny

@@ -14,16 +19,12 @@
     prepare_example_inputs_and_model,
     convert_quantized_tensor,
     graph_has_ops,
+    patch_none_example,
 )
 from openvino import opset11 as ops
 from openvino.frontend.pytorch import quantized, patch_model
 from openvino.frontend.pytorch.module_extension import ModuleExtension
 
-import inspect
-import logging
-import typing
-import torch
-
 log = logging.getLogger(__name__)
 
 

@@ -133,6 +134,7 @@ def _get_scripted_model(self, pt_module, example_inputs=None, skip_freeze=False,
             scripted = torch.jit.script(pt_module)
             freeze_by_default = True
         else:
+            pt_module, example_inputs = patch_none_example(pt_module, example_inputs)
             input_parameters, input_signature, pt_module, self._input_is_list = prepare_example_inputs_and_model(
                 example_inputs, input_params, pt_module)
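
For orientation, a minimal end-to-end sketch of what this wiring enables, mirroring the new tests further down (a sketch, not the decoder's internals; the exact resulting graph depends on tracing and cleanup passes):

```python
import torch
import openvino as ov

class PTModel(torch.nn.Module):
    def forward(self, a, b):
        # b is optional: when the example supplies None, the patched module
        # fixes b to None, and only `a` remains an input of the converted model
        if b is None:
            b = torch.tensor(1., dtype=torch.float32)
        return a + b

# A None entry in example_input was previously rejected; now the decoder
# runs patch_none_example before preparing example inputs for tracing.
ov_model = ov.convert_model(PTModel(), example_input=(torch.tensor([5., 6.]), None))
```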

src/bindings/python/src/openvino/frontend/pytorch/utils.py (+107 -8)

@@ -4,12 +4,16 @@
 # flake8: noqa
 # mypy: ignore-errors
 
+import inspect
+import logging
 import torch
 import numpy as np
 
 from openvino import op, Type as OVType, Shape, Tensor
 from openvino import opset11 as ops
 
+log = logging.getLogger(__name__)
+
 
 def make_constant(*args, **kwargs):
     return op.Constant(*args, **kwargs)

@@ -162,6 +166,23 @@ def forward(self, {input_sign}):
 """
 
 
+def build_wrapper(template, model):
+    """
+    Builds a wrapper around the given model using the provided template.
+    """
+    result = {}
+    try:
+        exec(template, result)
+
+        wrapped_model = result["ModelWrapper"](model)
+        wrapped_model.eval()
+    # if wrapping failed, it is better to return the original model to avoid confusing the user with the error message
+    except Exception:
+        log.error("Failed to build model wrapper.")
+        wrapped_model = model
+    return wrapped_model
+
+
 def process_dict_inputs(inputs, input_params, model):
     ordered_inputs = []
     for input_name in input_params:

@@ -203,15 +224,8 @@ def process_dict_inputs(inputs, input_params, model):
 
     wrapper_class = wrapper_template.format(input_sign=", ".join(
         input_sign_str), example_input=", ".join(input_params_str))
-    result = {}
-    try:
-        exec(wrapper_class, result)
 
-        wrapped_model = result["ModelWrapper"](model)
-        wrapped_model.eval()
-    # if wrapping failed, it is better to return original model for avoid user confusion regarding error message
-    except Exception:
-        wrapped_model = model
+    wrapped_model = build_wrapper(wrapper_class, model)
 
     return {"example_inputs": [inputs[name] for name in ordered_inputs]}, ordered_inputs, wrapped_model

@@ -265,3 +279,88 @@ def convert_quantized_tensor(qtensor: torch.Tensor, shared_memory: bool):
         sub = ops.subtract(convert, zero_point)
         return ops.multiply(sub, scale).outputs()
     assert False, "Unsupported qscheme"
+
+
+def process_individual_input(x, x_name):
+    """
+    Processes an individual input and generates a signature,
+    parameter string, example entry, and a wrap flag.
+    """
+    sign = None
+    param = None
+    example_entry = None
+    to_wrap = False
+    if isinstance(x, tuple):
+        internal_input = []
+        new_tuple = []
+        index = 0
+        for v in x:
+            if v is None:
+                to_wrap = True
+                internal_input.append("None")
+            else:
+                internal_input.append(f"{x_name}[{index}]")
+                new_tuple.append(v)
+                index += 1
+        param = f"({', '.join(internal_input)},)"
+        if len(new_tuple) > 0:
+            example_entry = tuple(new_tuple)
+            sign = x_name
+    elif x is None:
+        to_wrap = True
+        param = "None"
+    else:
+        sign = x_name
+        param = x_name
+        example_entry = x
+    return sign, param, example_entry, to_wrap
+
+
+def patch_none_example(model: torch.nn.Module, example):
+    """
+    Patches a PyTorch model to handle None values in the input example.
+    """
+    callable_func = getattr(model, "forward", model.__call__)
+    input_params = inspect.signature(callable_func).parameters
+    input_signature = list(input_params)
+    input_sign_str = []
+    input_params_str = []
+    to_wrap = False
+    if isinstance(example, tuple) and len(input_signature) >= len(example):
+        new_example = []
+        for i, x in enumerate(example):
+            x_name = input_signature[i]
+            sign, param, example_entry, _to_wrap = process_individual_input(x, x_name)
+            to_wrap = to_wrap or _to_wrap
+            if sign is not None:
+                input_sign_str.append(str(input_params[sign]))
+                input_params_str.append(param)
+            if example_entry is not None:
+                new_example.append(example_entry)
+        if to_wrap:
+            wrapper_class = wrapper_template.format(input_sign=", ".join(input_sign_str),
+                                                    example_input=", ".join(input_params_str))
+            wrapped_model = build_wrapper(wrapper_class, model)
+            log.warning("Model has None in the example input. The input "
+                        "with None will be removed from the resulting model.")
+            return wrapped_model, tuple(new_example)
+    elif isinstance(example, dict) and len(input_signature) >= len(example):
+        new_example = {}
+        input_signature = [s for s in input_signature if s in example]
+        for x_name in input_signature:
+            x = example[x_name]
+            sign, param, example_entry, _to_wrap = process_individual_input(x, x_name)
+            to_wrap = to_wrap or _to_wrap
+            if sign is not None:
+                input_sign_str.append(str(input_params[sign]))
+                input_params_str.append(f"{x_name}={param}")
+            if example_entry is not None:
+                new_example[x_name] = example_entry
+        if to_wrap:
+            wrapper_class = wrapper_template.format(input_sign=", ".join(input_sign_str),
+                                                    example_input=", ".join(input_params_str))
+            wrapped_model = build_wrapper(wrapper_class, model)
+            log.warning("Model has None in the example input. The input "
+                        "with None will be removed from the resulting model.")
+            return wrapped_model, new_example
+    return model, example
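
A sketch of how the two new helpers cooperate (internal names from this diff, not a public API; the wrapper body is paraphrased from wrapper_template, which this hunk does not show):

```python
import torch
from openvino.frontend.pytorch.utils import patch_none_example

class PTModel(torch.nn.Module):
    def forward(self, a):
        # `a` is a tuple whose second element may be None
        return a[0] if a[1] is None else a[0] + a[1]

# For x=(tensor, None) and x_name="a", process_individual_input yields
#   sign="a", param="(a[0], None,)", example_entry=(tensor,), to_wrap=True
# so patch_none_example formats wrapper_template with that signature and
# builds a wrapper whose forward re-inserts the None before delegating to
# the original model, while the returned example has the None stripped.
model, example = patch_none_example(PTModel(), ((torch.ones(2), None),))
assert len(example[0]) == 1  # the None entry was removed from the example
```

Note that `index` advances only for non-None elements, so the generated `a[0]`, `a[1]`, ... expressions index into the stripped tuple that the wrapper actually receives at call time.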

tests/layer_tests/ovc_python_api_tests/test_pytorch.py (+99 -1)

@@ -1012,6 +1012,100 @@ def forward(self, a, b):
     ), "output": "some_name"}
 
 
+def create_pytorch_module_with_none_example(tmp_dir):
+    class PTModel(torch.nn.Module):
+        def forward(self, a, b):
+            if b is None:
+                b = torch.tensor(1., dtype=torch.float32)
+            return a + b
+
+    net = PTModel()
+    a = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
+    add = ov.opset10.add(a, np.float32([1.]))
+    ref_model = Model([add], [a], "test")
+    return net, ref_model, {
+        "example_input": (
+            torch.tensor([5, 6], dtype=torch.float32),
+            None
+        ),
+        "compress_to_fp16": False}
+
+
+def create_pytorch_module_with_none_dict_example(tmp_dir):
+    class PTModel(torch.nn.Module):
+        def forward(self, a, b):
+            if b is None:
+                b = torch.tensor(1., dtype=torch.float32)
+            return a + b
+
+    net = PTModel()
+    a = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
+    add = ov.opset10.add(a, np.float32([1.]))
+    ref_model = Model([add], [a], "test")
+    return net, ref_model, {
+        "example_input": {
+            "a": torch.tensor([5, 6], dtype=torch.float32),
+            "b": None,
+        },
+        "compress_to_fp16": False}
+
+
+def create_pytorch_module_with_none_in_tuple(tmp_dir):
+    class PTModel(torch.nn.Module):
+        def forward(self, a, b):
+            x = a[0]
+            if a[1] is None:
+                x = x + torch.tensor(1., dtype=torch.float32)
+            else:
+                x = x + a[1]
+            if a[2] is None:
+                x = x + torch.tensor(1., dtype=torch.float32)
+            else:
+                x = x + a[2]
+            return x + b
+
+    net = PTModel()
+    a = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
+    b = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
+    add = ov.opset10.add(a, np.float32([2.]))
+    add2 = ov.opset10.add(add, b)
+    ref_model = Model([add2], [a, b], "test")
+    return net, ref_model, {
+        "example_input": {
+            "a": (torch.tensor([5, 6], dtype=torch.float32), None, None),
+            "b": torch.tensor([5, 6], dtype=torch.float32),
+        },
+        "compress_to_fp16": False}
+
+
+def create_pytorch_module_with_none_in_tuple_case2(tmp_dir):
+    class PTModel(torch.nn.Module):
+        def forward(self, a, b):
+            x = a[0]
+            if a[1] is None:
+                x = x + torch.tensor(1., dtype=torch.float32)
+            else:
+                x = x + a[1]
+            if a[2] is None:
+                x = x + torch.tensor(1., dtype=torch.float32)
+            else:
+                x = x + a[2]
+            return x + b
+
+    net = PTModel()
+    a = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
+    add = ov.opset10.add(a, np.float32([2.]))
+    b = ov.opset10.parameter(PartialShape([-1]), dtype=np.float32)
+    add2 = ov.opset10.add(add, b)
+    ref_model = Model([add2], [a, b], "test")
+    return net, ref_model, {
+        "example_input": (
+            (torch.tensor([5, 6], dtype=torch.float32), None, None),
+            torch.tensor([5, 6], dtype=torch.float32),
+        ),
+        "compress_to_fp16": False}
+
+
 class TestMoConvertPyTorch(CommonMOConvertTest):
     test_data = [
         'create_pytorch_nn_module_case1',

@@ -1062,7 +1156,11 @@ class TestMoConvertPyTorch(CommonMOConvertTest):
         'create_pytorch_module_with_nested_inputs6',
         'create_pytorch_module_with_nested_list_and_single_input',
         'create_pytorch_module_with_single_input_as_list',
-        'create_pytorch_module_with_nested_dict_input'
+        'create_pytorch_module_with_nested_dict_input',
+        'create_pytorch_module_with_none_example',
+        'create_pytorch_module_with_none_dict_example',
+        'create_pytorch_module_with_none_in_tuple',
+        'create_pytorch_module_with_none_in_tuple_case2',
     ]
 
     @pytest.mark.parametrize("create_model", test_data)

tools/ovc/openvino/tools/ovc/moc_frontend/pytorch_frontend_utils.py (+7 -1)

@@ -75,7 +75,11 @@ def get_pytorch_decoder(model, example_inputs, args):
     else:
         decoder = model
     args['input_model'] = decoder
-    args["example_input"] = inputs
+    ei = getattr(decoder, "_example_input", None)
+    if ei is not None:
+        args["example_input"] = ei
+    else:
+        args["example_input"] = inputs
 
     return args

@@ -250,6 +254,8 @@ def to_torch_tensor(tensor):
         return tuple(to_torch_tensor(x) for x in tensor)
     if isinstance(tensor, dict) and all(isinstance(k, str) for k in tensor.keys()):
         return dict((k, to_torch_tensor(x)) for k, x in tensor.items())
+    if tensor is None:
+        return None
     else:
         raise Error("Unexpected type of example_input. Supported types torch.Tensor, np.array or ov.Tensor. "
                     "Got {}".format(type(tensor)))
