15
15
16
16
17
17
def convert_sklearn_gradient_boosting_classifier (
18
- scope ,
19
- operator ,
20
- container ,
21
- op_type = "TreeEnsembleClassifier" ,
22
- op_domain = "ai.onnx.ml" ,
23
- op_version = 1 ,
18
+ scope ,
19
+ operator ,
20
+ container ,
21
+ op_type = "TreeEnsembleClassifier" ,
22
+ op_domain = "ai.onnx.ml" ,
23
+ op_version = 1 ,
24
24
):
25
25
dtype = guess_numpy_type (operator .inputs [0 ].type )
26
26
if dtype != np .float64 :
@@ -37,7 +37,6 @@ def convert_sklearn_gradient_boosting_classifier(
37
37
attrs ["name" ] = scope .get_unique_operator_name (op_type )
38
38
39
39
transform = "LOGISTIC" if op .n_classes_ == 2 else "SOFTMAX"
40
-
41
40
if op .init == "zero" :
42
41
loss = op ._loss if hasattr (op , "_loss" ) else op .loss_
43
42
if hasattr (loss , "K" ):
@@ -99,11 +98,11 @@ def convert_sklearn_gradient_boosting_classifier(
99
98
if dtype is not None :
100
99
for k in attrs :
101
100
if k in (
102
- "nodes_values" ,
103
- "class_weights" ,
104
- "target_weights" ,
105
- "nodes_hitrates" ,
106
- "base_values" ,
101
+ "nodes_values" ,
102
+ "class_weights" ,
103
+ "target_weights" ,
104
+ "nodes_hitrates" ,
105
+ "base_values" ,
107
106
):
108
107
attrs [k ] = np .array (attrs [k ], dtype = dtype )
109
108
@@ -155,12 +154,12 @@ def convert_sklearn_gradient_boosting_classifier(
155
154
156
155
157
156
def convert_sklearn_gradient_boosting_regressor (
158
- scope ,
159
- operator ,
160
- container ,
161
- op_type = "TreeEnsembleRegressor" ,
162
- op_domain = "ai.onnx.ml" ,
163
- op_version = 1 ,
157
+ scope ,
158
+ operator ,
159
+ container ,
160
+ op_type = "TreeEnsembleRegressor" ,
161
+ op_domain = "ai.onnx.ml" ,
162
+ op_version = 1 ,
164
163
):
165
164
op = operator .raw_operator
166
165
attrs = get_default_tree_regressor_attribute_pairs ()
@@ -207,11 +206,11 @@ def convert_sklearn_gradient_boosting_regressor(
207
206
if dtype is not None :
208
207
for k in attrs :
209
208
if k in (
210
- "nodes_values" ,
211
- "class_weights" ,
212
- "target_weights" ,
213
- "nodes_hitrates" ,
214
- "base_values" ,
209
+ "nodes_values" ,
210
+ "class_weights" ,
211
+ "target_weights" ,
212
+ "nodes_hitrates" ,
213
+ "base_values" ,
215
214
):
216
215
attrs [k ] = np .array (attrs [k ], dtype = dtype )
217
216
0 commit comments