Skip to content

Commit 51cb819

Browse files
author
Maksym Sydorchuk
committed
refactoring
Signed-off-by: Maksym Sydorchuk <[email protected]>
1 parent bfea7e1 commit 51cb819

File tree

1 file changed

+22
-23
lines changed

1 file changed

+22
-23
lines changed

skl2onnx/operator_converters/gradient_boosting.py

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515

1616

1717
def convert_sklearn_gradient_boosting_classifier(
18-
scope,
19-
operator,
20-
container,
21-
op_type="TreeEnsembleClassifier",
22-
op_domain="ai.onnx.ml",
23-
op_version=1,
18+
scope,
19+
operator,
20+
container,
21+
op_type="TreeEnsembleClassifier",
22+
op_domain="ai.onnx.ml",
23+
op_version=1,
2424
):
2525
dtype = guess_numpy_type(operator.inputs[0].type)
2626
if dtype != np.float64:
@@ -37,7 +37,6 @@ def convert_sklearn_gradient_boosting_classifier(
3737
attrs["name"] = scope.get_unique_operator_name(op_type)
3838

3939
transform = "LOGISTIC" if op.n_classes_ == 2 else "SOFTMAX"
40-
4140
if op.init == "zero":
4241
loss = op._loss if hasattr(op, "_loss") else op.loss_
4342
if hasattr(loss, "K"):
@@ -99,11 +98,11 @@ def convert_sklearn_gradient_boosting_classifier(
9998
if dtype is not None:
10099
for k in attrs:
101100
if k in (
102-
"nodes_values",
103-
"class_weights",
104-
"target_weights",
105-
"nodes_hitrates",
106-
"base_values",
101+
"nodes_values",
102+
"class_weights",
103+
"target_weights",
104+
"nodes_hitrates",
105+
"base_values",
107106
):
108107
attrs[k] = np.array(attrs[k], dtype=dtype)
109108

@@ -155,12 +154,12 @@ def convert_sklearn_gradient_boosting_classifier(
155154

156155

157156
def convert_sklearn_gradient_boosting_regressor(
158-
scope,
159-
operator,
160-
container,
161-
op_type="TreeEnsembleRegressor",
162-
op_domain="ai.onnx.ml",
163-
op_version=1,
157+
scope,
158+
operator,
159+
container,
160+
op_type="TreeEnsembleRegressor",
161+
op_domain="ai.onnx.ml",
162+
op_version=1,
164163
):
165164
op = operator.raw_operator
166165
attrs = get_default_tree_regressor_attribute_pairs()
@@ -207,11 +206,11 @@ def convert_sklearn_gradient_boosting_regressor(
207206
if dtype is not None:
208207
for k in attrs:
209208
if k in (
210-
"nodes_values",
211-
"class_weights",
212-
"target_weights",
213-
"nodes_hitrates",
214-
"base_values",
209+
"nodes_values",
210+
"class_weights",
211+
"target_weights",
212+
"nodes_hitrates",
213+
"base_values",
215214
):
216215
attrs[k] = np.array(attrs[k], dtype=dtype)
217216

0 commit comments

Comments
 (0)