Skip to content

Commit 53575d8

Browse files
Match transformers behavior with return_dict (#2269)
* fix
* fix
* more
1 parent dcd5ecb commit 53575d8

File tree

3 files changed

+59
-60
lines changed

3 files changed

+59
-60
lines changed

optimum/onnxruntime/modeling_diffusion.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ def forward(
611611
timestep_cond: Optional[Union[np.ndarray, torch.Tensor]] = None,
612612
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
613613
added_cond_kwargs: Optional[Dict[str, Any]] = None,
614-
return_dict: bool = False,
614+
return_dict: bool = True,
615615
):
616616
use_torch = isinstance(sample, torch.Tensor)
617617

@@ -631,8 +631,8 @@ def forward(
631631
onnx_outputs = self.session.run(None, onnx_inputs)
632632
model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs)
633633

634-
if return_dict:
635-
return model_outputs
634+
if not return_dict:
635+
return tuple(model_outputs.values())
636636

637637
return ModelOutput(**model_outputs)
638638

@@ -648,7 +648,7 @@ def forward(
648648
txt_ids: Optional[Union[np.ndarray, torch.Tensor]] = None,
649649
img_ids: Optional[Union[np.ndarray, torch.Tensor]] = None,
650650
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
651-
return_dict: bool = False,
651+
return_dict: bool = True,
652652
):
653653
use_torch = isinstance(hidden_states, torch.Tensor)
654654

@@ -667,8 +667,8 @@ def forward(
667667
onnx_outputs = self.session.run(None, onnx_inputs)
668668
model_outputs = self.prepare_onnx_outputs(use_torch, *onnx_outputs)
669669

670-
if return_dict:
671-
return model_outputs
670+
if not return_dict:
671+
return tuple(model_outputs.values())
672672

673673
return ModelOutput(**model_outputs)
674674

@@ -679,7 +679,7 @@ def forward(
679679
input_ids: Union[np.ndarray, torch.Tensor],
680680
attention_mask: Optional[Union[np.ndarray, torch.Tensor]] = None,
681681
output_hidden_states: Optional[bool] = None,
682-
return_dict: bool = False,
682+
return_dict: bool = True,
683683
):
684684
use_torch = isinstance(input_ids, torch.Tensor)
685685

@@ -700,8 +700,8 @@ def forward(
700700
for i in range(num_layers):
701701
model_outputs.pop(f"hidden_states.{i}", None)
702702

703-
if return_dict:
704-
return model_outputs
703+
if not return_dict:
704+
return tuple(model_outputs.values())
705705

706706
return ModelOutput(**model_outputs)
707707

@@ -722,7 +722,7 @@ def forward(
722722
self,
723723
sample: Union[np.ndarray, torch.Tensor],
724724
generator: Optional[torch.Generator] = None,
725-
return_dict: bool = False,
725+
return_dict: bool = True,
726726
):
727727
use_torch = isinstance(sample, torch.Tensor)
728728

@@ -740,8 +740,8 @@ def forward(
740740
parameters=model_outputs.pop("latent_parameters")
741741
)
742742

743-
if return_dict:
744-
return model_outputs
743+
if not return_dict:
744+
return tuple(model_outputs.values())
745745

746746
return ModelOutput(**model_outputs)
747747

@@ -762,7 +762,7 @@ def forward(
762762
self,
763763
latent_sample: Union[np.ndarray, torch.Tensor],
764764
generator: Optional[torch.Generator] = None,
765-
return_dict: bool = False,
765+
return_dict: bool = True,
766766
):
767767
use_torch = isinstance(latent_sample, torch.Tensor)
768768

@@ -775,8 +775,8 @@ def forward(
775775
if "latent_sample" in model_outputs:
776776
model_outputs["latents"] = model_outputs.pop("latent_sample")
777777

778-
if return_dict:
779-
return model_outputs
778+
if not return_dict:
779+
return tuple(model_outputs.values())
780780

781781
return ModelOutput(**model_outputs)
782782

optimum/onnxruntime/modeling_ort.py

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,7 @@ def forward(
10771077
input_features: Optional[Union[torch.Tensor, np.ndarray]] = None,
10781078
input_values: Optional[Union[torch.Tensor, np.ndarray]] = None,
10791079
*,
1080-
return_dict: bool = False,
1080+
return_dict: bool = True,
10811081
**kwargs,
10821082
):
10831083
# Warn about any unexpected kwargs using the helper method
@@ -1134,8 +1134,8 @@ def forward(
11341134
# TODO: This allows to support sentence-transformers models (sentence embedding), but is not validated.
11351135
last_hidden_state = next(iter(model_outputs.values()))
11361136

1137-
if return_dict:
1138-
return {"last_hidden_state": last_hidden_state}
1137+
if not return_dict:
1138+
return (last_hidden_state,)
11391139

11401140
# converts output to namedtuple for pipelines post-processing
11411141
return BaseModelOutput(last_hidden_state=last_hidden_state)
@@ -1251,7 +1251,7 @@ def forward(
12511251
attention_mask: Optional[Union[torch.Tensor, np.ndarray]] = None,
12521252
token_type_ids: Optional[Union[torch.Tensor, np.ndarray]] = None,
12531253
*,
1254-
return_dict: bool = False,
1254+
return_dict: bool = True,
12551255
**kwargs,
12561256
):
12571257
# Warn about any unexpected kwargs using the helper method
@@ -1288,8 +1288,8 @@ def forward(
12881288

12891289
logits = model_outputs["logits"]
12901290

1291-
if return_dict:
1292-
return {"logits": logits}
1291+
if not return_dict:
1292+
return (logits,)
12931293

12941294
# converts output to namedtuple for pipelines post-processing
12951295
return MaskedLMOutput(logits=logits)
@@ -1353,7 +1353,7 @@ def forward(
13531353
attention_mask: Optional[Union[torch.Tensor, np.ndarray]] = None,
13541354
token_type_ids: Optional[Union[torch.Tensor, np.ndarray]] = None,
13551355
*,
1356-
return_dict: bool = False,
1356+
return_dict: bool = True,
13571357
**kwargs,
13581358
):
13591359
# Warn about any unexpected kwargs using the helper method
@@ -1388,8 +1388,8 @@ def forward(
13881388
start_logits = model_outputs["start_logits"]
13891389
end_logits = model_outputs["end_logits"]
13901390

1391-
if return_dict:
1392-
return {"start_logits": start_logits, "end_logits": end_logits}
1391+
if not return_dict:
1392+
return (start_logits, end_logits)
13931393

13941394
# converts output to namedtuple for pipelines post-processing
13951395
return QuestionAnsweringModelOutput(start_logits=start_logits, end_logits=end_logits)
@@ -1468,7 +1468,7 @@ def forward(
14681468
attention_mask: Optional[Union[torch.Tensor, np.ndarray]] = None,
14691469
token_type_ids: Optional[Union[torch.Tensor, np.ndarray]] = None,
14701470
*,
1471-
return_dict: bool = False,
1471+
return_dict: bool = True,
14721472
**kwargs,
14731473
):
14741474
# Warn about any unexpected kwargs using the helper method
@@ -1505,8 +1505,8 @@ def forward(
15051505

15061506
logits = model_outputs["logits"]
15071507

1508-
if return_dict:
1509-
return {"logits": logits}
1508+
if not return_dict:
1509+
return (logits,)
15101510

15111511
# converts output to namedtuple for pipelines post-processing
15121512
return SequenceClassifierOutput(logits=logits)
@@ -1571,7 +1571,7 @@ def forward(
15711571
attention_mask: Optional[Union[torch.Tensor, np.ndarray]] = None,
15721572
token_type_ids: Optional[Union[torch.Tensor, np.ndarray]] = None,
15731573
*,
1574-
return_dict: bool = False,
1574+
return_dict: bool = True,
15751575
**kwargs,
15761576
):
15771577
# Warn about any unexpected kwargs using the helper method
@@ -1608,8 +1608,8 @@ def forward(
16081608

16091609
logits = model_outputs["logits"]
16101610

1611-
if return_dict:
1612-
return {"logits": logits}
1611+
if not return_dict:
1612+
return (logits,)
16131613

16141614
return TokenClassifierOutput(logits=logits)
16151615

@@ -1667,7 +1667,7 @@ def forward(
16671667
attention_mask: Optional[Union[torch.Tensor, np.ndarray]] = None,
16681668
token_type_ids: Optional[Union[torch.Tensor, np.ndarray]] = None,
16691669
*,
1670-
return_dict: bool = False,
1670+
return_dict: bool = True,
16711671
**kwargs,
16721672
):
16731673
# Warn about any unexpected kwargs using the helper method
@@ -1704,8 +1704,8 @@ def forward(
17041704

17051705
logits = model_outputs["logits"]
17061706

1707-
if return_dict:
1708-
return {"logits": logits}
1707+
if not return_dict:
1708+
return (logits,)
17091709

17101710
# converts output to namedtuple for pipelines post-processing
17111711
return MultipleChoiceModelOutput(logits=logits)
@@ -1770,7 +1770,7 @@ def forward(
17701770
self,
17711771
pixel_values: Union[torch.Tensor, np.ndarray],
17721772
*,
1773-
return_dict: bool = False,
1773+
return_dict: bool = True,
17741774
**kwargs,
17751775
):
17761776
# Warn about any unexpected kwargs using the helper method
@@ -1802,8 +1802,8 @@ def forward(
18021802

18031803
logits = model_outputs["logits"]
18041804

1805-
if return_dict:
1806-
return {"logits": logits}
1805+
if not return_dict:
1806+
return (logits,)
18071807

18081808
# converts output to namedtuple for pipelines post-processing
18091809
return ImageClassifierOutput(logits=logits)
@@ -1868,7 +1868,7 @@ def forward(
18681868
self,
18691869
pixel_values: Union[torch.Tensor, np.ndarray],
18701870
*,
1871-
return_dict: bool = False,
1871+
return_dict: bool = True,
18721872
**kwargs,
18731873
):
18741874
# Warn about any unexpected kwargs using the helper method
@@ -1900,8 +1900,8 @@ def forward(
19001900

19011901
logits = model_outputs["logits"]
19021902

1903-
if return_dict:
1904-
return {"logits": logits}
1903+
if not return_dict:
1904+
return (logits,)
19051905

19061906
# converts output to namedtuple for pipelines post-processing
19071907
return SemanticSegmenterOutput(logits=logits)
@@ -1996,7 +1996,7 @@ def forward(
19961996
attention_mask: Optional[Union[torch.Tensor, np.ndarray]] = None,
19971997
input_features: Optional[Union[torch.Tensor, np.ndarray]] = None,
19981998
*,
1999-
return_dict: bool = False,
1999+
return_dict: bool = True,
20002000
**kwargs,
20012001
):
20022002
# Warn about any unexpected kwargs using the helper method
@@ -2038,8 +2038,8 @@ def forward(
20382038

20392039
logits = model_outputs["logits"]
20402040

2041-
if return_dict:
2042-
return {"logits": logits}
2041+
if not return_dict:
2042+
return (logits,)
20432043

20442044
# converts output to namedtuple for pipelines post-processing
20452045
return SequenceClassifierOutput(logits=logits)
@@ -2092,7 +2092,7 @@ def forward(
20922092
self,
20932093
input_values: Optional[Union[torch.Tensor, np.ndarray]] = None,
20942094
*,
2095-
return_dict: bool = False,
2095+
return_dict: bool = True,
20962096
**kwargs,
20972097
):
20982098
# Warn about any unexpected kwargs using the helper method
@@ -2134,8 +2134,8 @@ def forward(
21342134

21352135
logits = model_outputs["logits"]
21362136

2137-
if return_dict:
2138-
return {"logits": logits}
2137+
if not return_dict:
2138+
return (logits,)
21392139

21402140
# converts output to namedtuple for pipelines post-processing
21412141
return CausalLMOutput(logits=logits)
@@ -2196,7 +2196,7 @@ def forward(
21962196
self,
21972197
input_values: Optional[Union[torch.Tensor, np.ndarray]] = None,
21982198
*,
2199-
return_dict: bool = False,
2199+
return_dict: bool = True,
22002200
**kwargs,
22012201
):
22022202
# Warn about any unexpected kwargs using the helper method
@@ -2231,8 +2231,8 @@ def forward(
22312231
logits = model_outputs["logits"]
22322232
embeddings = model_outputs["embeddings"]
22332233

2234-
if return_dict:
2235-
return {"logits": logits, "embeddings": embeddings}
2234+
if not return_dict:
2235+
return (logits, embeddings)
22362236

22372237
# converts output to namedtuple for pipelines post-processing
22382238
return XVectorOutput(logits=logits, embeddings=embeddings)
@@ -2285,7 +2285,7 @@ def forward(
22852285
self,
22862286
input_values: Optional[Union[torch.Tensor, np.ndarray]] = None,
22872287
*,
2288-
return_dict: bool = False,
2288+
return_dict: bool = True,
22892289
**kwargs,
22902290
):
22912291
# Warn about any unexpected kwargs using the helper method
@@ -2305,8 +2305,8 @@ def forward(
23052305

23062306
logits = model_outputs["logits"]
23072307

2308-
if return_dict:
2309-
return {"logits": logits}
2308+
if not return_dict:
2309+
return (logits,)
23102310

23112311
# converts output to namedtuple for pipelines post-processing
23122312
return TokenClassifierOutput(logits=logits)
@@ -2353,7 +2353,7 @@ def forward(
23532353
self,
23542354
pixel_values: Union[torch.Tensor, np.ndarray],
23552355
*,
2356-
return_dict: bool = False,
2356+
return_dict: bool = True,
23572357
**kwargs,
23582358
):
23592359
# Warn about any unexpected kwargs using the helper method
@@ -2390,8 +2390,8 @@ def forward(
23902390
model_outputs = self._prepare_onnx_outputs(use_torch, onnx_outputs)
23912391
reconstruction = model_outputs["reconstruction"]
23922392

2393-
if return_dict:
2394-
return {"reconstruction": reconstruction}
2393+
if not return_dict:
2394+
return (reconstruction,)
23952395

23962396
return ImageSuperResolutionOutput(reconstruction=reconstruction)
23972397

tests/onnxruntime/test_modeling.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2138,17 +2138,16 @@ def test_compare_to_transformers(self, model_arch):
21382138

21392139
for input_type in ["pt", "np"]:
21402140
tokens = tokenizer(text, return_tensors=input_type)
2141-
# Test default behavior (return_dict=False)
2141+
# Test default behavior (return_dict=True)
21422142
onnx_outputs = onnx_model(**tokens)
21432143
self.assertIsInstance(onnx_outputs, BaseModelOutput)
21442144
self.assertIn("last_hidden_state", onnx_outputs)
21452145
self.assertIsInstance(onnx_outputs.last_hidden_state, self.TENSOR_ALIAS_TO_TYPE[input_type])
21462146

2147-
# Test return_dict=True
2148-
onnx_outputs_dict = onnx_model(**tokens, return_dict=True)
2149-
self.assertIsInstance(onnx_outputs_dict, dict)
2150-
self.assertIn("last_hidden_state", onnx_outputs_dict)
2151-
self.assertIsInstance(onnx_outputs_dict["last_hidden_state"], self.TENSOR_ALIAS_TO_TYPE[input_type])
2147+
# Test return_dict=False
2148+
onnx_outputs_dict = onnx_model(**tokens, return_dict=False)
2149+
self.assertIsInstance(onnx_outputs_dict, tuple)
2150+
self.assertIsInstance(onnx_outputs_dict[0], self.TENSOR_ALIAS_TO_TYPE[input_type])
21522151

21532152
# compare tensor outputs
21542153
torch.testing.assert_close(

0 commit comments

Comments (0)