Skip to content

Commit b6586f1

Browse files
authored
Fix CastTransformer output type (#1200)
* update changelogs Signed-off-by: xadupre <[email protected]> * Fix unknown_value=np.nan in OrdinalEncoder Signed-off-by: xadupre <[email protected]> * changelogs Signed-off-by: xadupre <[email protected]> * Fix CastTransformer output type Signed-off-by: xadupre <[email protected]> * changes Signed-off-by: xadupre <[email protected]> * disable the test for old version of scikit-learn --------- Signed-off-by: xadupre <[email protected]>
1 parent 46c85e5 commit b6586f1

File tree

3 files changed

+51
-0
lines changed

3 files changed

+51
-0
lines changed

CHANGELOGS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## 1.20.0
44

5+
* Fixes CastTransformer output type
6+
[#1200](https://github.com/onnx/sklearn-onnx/issues/1200)
57
* Fixes unknown_value=np.nan in OrdinalEncoder
68
[#1198](https://github.com/onnx/sklearn-onnx/issues/1198)
79
* Enhance OrdinalEncoder conversion to handle infrequent categories

skl2onnx/_parse.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class OutlierMixin:
5252
from .common.data_types import (
5353
DictionaryType,
5454
Int64TensorType,
55+
DoubleTensorType,
5556
SequenceType,
5657
StringTensorType,
5758
TensorType,
@@ -62,6 +63,7 @@ class OutlierMixin:
6263
from .common.utils_checking import check_signature
6364
from .common.utils_classifier import get_label_classes
6465
from .common.utils_sklearn import _process_options
66+
from .sklapi import CastTransformer
6567

6668

6769
do_not_merge_columns = tuple(
@@ -250,6 +252,18 @@ def _parse_sklearn_simple_model(scope, model, inputs, custom_parsers=None, alias
250252
otype = guess_tensor_type(inputs[0].type)
251253
variable = scope.declare_local_variable("variable", otype)
252254
this_operator.outputs.append(variable)
255+
elif type(model) in {CastTransformer}:
256+
dtype = model.dtype
257+
if dtype == np.float32:
258+
cls = FloatTensorType
259+
elif dtype == np.float64:
260+
cls = DoubleTensorType
261+
elif dtype == np.int64:
262+
cls = Int64TensorType
263+
else:
264+
raise NotImplementedError(f"Unexpected dtype={dtype} for model={model}")
265+
variable = scope.declare_local_variable("cast", cls())
266+
this_operator.outputs.append(variable)
253267
else:
254268
if hasattr(model, "get_feature_names_out"):
255269
try:

tests/test_sklearn_imputer_converter.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import pandas as pd
1111
from numpy.testing import assert_almost_equal
1212
import sklearn
13+
from sklearn.tree import DecisionTreeRegressor
14+
from sklearn.pipeline import Pipeline
1315

1416
try:
1517
from sklearn.preprocessing import Imputer
@@ -24,8 +26,10 @@
2426

2527
from onnxruntime import __version__ as ort_version
2628

29+
from skl2onnx.sklapi import CastTransformer
2730
from skl2onnx import convert_sklearn
2831
from skl2onnx.common.data_types import (
32+
DoubleTensorType,
2933
FloatTensorType,
3034
Int64TensorType,
3135
StringTensorType,
@@ -125,6 +129,37 @@ def test_simple_imputer_float_inputs(self):
125129
basename="SklearnSimpleImputerMeanFloat32",
126130
)
127131

132+
@unittest.skipIf(SimpleImputer is None, reason="SimpleImputer changed in 0.20")
133+
@unittest.skipIf(
134+
pv.Version(skl_ver) <= pv.Version("1.4.0"),
135+
reason="unexpected pipeline + transform",
136+
)
137+
def test_simple_imputer_double_inputs(self):
138+
model = Pipeline(
139+
[
140+
("cast32", CastTransformer(dtype=np.float32)),
141+
("imputer", SimpleImputer()),
142+
("dt", DecisionTreeRegressor(max_depth=2)),
143+
]
144+
)
145+
data = np.array([[1, 2], [np.nan, 3], [7, 6]], dtype=np.float64)
146+
y = [0, 1, 0]
147+
model.fit(data, y)
148+
149+
model_onnx = convert_sklearn(
150+
model,
151+
"double",
152+
[("input", DoubleTensorType([None, 2]))],
153+
final_types=[("y", FloatTensorType([None, 1]))],
154+
target_opset=TARGET_OPSET,
155+
)
156+
dump_data_and_model(
157+
np.array(data, dtype=np.float64),
158+
model,
159+
model_onnx,
160+
basename="SklearnSimpleImputerDoubleInputs",
161+
)
162+
128163
@unittest.skipIf(SimpleImputer is None, reason="SimpleImputer changed in 0.20")
129164
@unittest.skipIf(
130165
pv.Version(ort_version) <= pv.Version("1.11.0"),

0 commit comments

Comments
 (0)