Skip to content

Commit ee22fcd

Browse files
authored
Fix missing input dimension in StackClassifier with ColumnTransformer (#1201)
* Fix missing input dimension in StackClassifier with ColumnTransformer Signed-off-by: xadupre <[email protected]> * changelogs Signed-off-by: xadupre <[email protected]> --------- Signed-off-by: xadupre <[email protected]>
1 parent b6586f1 commit ee22fcd

File tree

4 files changed

+107
-1
lines changed

4 files changed

+107
-1
lines changed

CHANGELOGS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## 1.20.0
44

5+
* Fixes missing dimension (number of features) in StackingClassifier
6+
[#1201](https://github.com/onnx/sklearn-onnx/issues/1201)
57
* Fixes CastTransformer output type
68
[#1200](https://github.com/onnx/sklearn-onnx/issues/1200)
79
* Fixes unknown_value=np.nan in OrdinalEncoder

skl2onnx/common/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ def get_column_index(i, inputs):
9191
return 0, 0
9292
vi = 0
9393
pos = 0
94+
assert (
95+
len(inputs[0].type.shape) == 2
96+
), f"Unexpect rank={len(inputs[0].type.shape)} for inputs={inputs}, i={i}"
9497
end = inputs[0].type.shape[1] if isinstance(inputs[0].type, TensorType) else 1
9598
if end is None:
9699
raise RuntimeError(

skl2onnx/operator_converters/stacking.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,10 @@ def convert_sklearn_stacking_classifier(
137137

138138
merged_proba_tensor = _transform(scope, operator, container, stacking_op)
139139
merge_proba = scope.declare_local_variable(
140-
"stack_merge_proba", operator.inputs[0].type.__class__()
140+
"stack_merge_proba",
141+
operator.inputs[0].type.__class__(
142+
[None, stacking_op.final_estimator_.n_features_in_]
143+
),
141144
)
142145
container.add_node("Identity", [merged_proba_tensor], [merge_proba.onnx_name])
143146
prob = _fetch_scores(

tests/test_sklearn_stacking.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,104 @@ def converter(scope, operator, container):
564564
got = sess.run(None, {"X": x})[0]
565565
self.assertEqual(got.shape[0], x.shape[0])
566566

567+
def test_model_stacking_classifier_column_transformer_issue_1199(self):
568+
# see https://github.com/onnx/sklearn-onnx/issues/1199
569+
import random
570+
571+
import numpy as np
572+
from skl2onnx import to_onnx
573+
from sklearn.compose import ColumnTransformer
574+
from sklearn.datasets import make_classification
575+
from sklearn.ensemble import StackingClassifier, RandomForestClassifier
576+
from sklearn.linear_model import LogisticRegression
577+
from sklearn.model_selection import train_test_split
578+
from sklearn.pipeline import Pipeline
579+
from sklearn.preprocessing import StandardScaler
580+
581+
np.random.seed(42)
582+
random.seed(42)
583+
584+
X, y = make_classification(n_samples=1000, n_features=5, random_state=42)
585+
586+
pipeline = Pipeline(
587+
steps=[
588+
(
589+
"stacking_classifier",
590+
StackingClassifier(
591+
estimators=[
592+
(
593+
"tree",
594+
Pipeline(
595+
[
596+
(
597+
"tree_column_selector",
598+
ColumnTransformer(
599+
[
600+
(
601+
"tree_cols",
602+
"passthrough",
603+
[0, 1, 2],
604+
)
605+
],
606+
remainder="drop",
607+
),
608+
),
609+
("tree_classifier", RandomForestClassifier()),
610+
]
611+
),
612+
)
613+
],
614+
final_estimator=Pipeline(
615+
[
616+
(
617+
"feature_combiner",
618+
ColumnTransformer(
619+
[
620+
(
621+
"standardize_proba",
622+
Pipeline(
623+
[
624+
(
625+
"logit_transform",
626+
StandardScaler(),
627+
)
628+
]
629+
),
630+
[0],
631+
),
632+
("other_features", "passthrough", [4, 5]),
633+
],
634+
remainder="drop",
635+
),
636+
),
637+
("final_logistic", LogisticRegression()),
638+
]
639+
),
640+
cv=2,
641+
stack_method="auto",
642+
passthrough=True,
643+
),
644+
)
645+
]
646+
)
647+
648+
X_train, X_test, y_train, y_test = train_test_split(
649+
X, y, test_size=0.2, random_state=4
650+
)
651+
pipeline.fit(X_train, y_train)
652+
expected = pipeline.predict_proba(X_train)
653+
model_onnx = to_onnx(
654+
pipeline,
655+
X_train[:1].astype(np.float32),
656+
verbose=1,
657+
options={"zipmap": False},
658+
)
659+
sess = InferenceSession(
660+
model_onnx.SerializeToString(), providers=["CPUExecutionProvider"]
661+
)
662+
got = sess.run(None, {"X": X_train.astype(np.float32)})
663+
assert_almost_equal(expected, got[1], decimal=5)
664+
567665

568666
if __name__ == "__main__":
569667
# import logging

0 commit comments

Comments
 (0)