Skip to content

Commit c73787f

Browse files
committed
Added support for the 'H2OEstimator.pmml_classes_' attribute
1 parent 620c521 commit c73787f

File tree

6 files changed

+439
-7
lines changed

6 files changed

+439
-7
lines changed

pmml-sklearn-h2o/src/main/java/h2o/estimators/H2OEstimator.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,10 @@
5151
import org.jpmml.python.ClassDictUtil;
5252
import org.jpmml.sklearn.Encodable;
5353
import org.jpmml.sklearn.SkLearnEncoder;
54+
import sklearn.Classifier;
5455
import sklearn.Estimator;
5556
import sklearn.HasClasses;
57+
import sklearn2pmml.SkLearn2PMMLFields;
5658

5759
public class H2OEstimator extends Estimator implements HasClasses, Encodable {
5860

@@ -107,6 +109,12 @@ public int getNumberOfOutputs(){
107109
public List<?> getClasses(){
108110
MojoModel mojoModel = getMojoModel();
109111

112+
if(containsKey(SkLearn2PMMLFields.PMML_CLASSES)){
113+
List<?> values = getListLike(SkLearn2PMMLFields.PMML_CLASSES);
114+
115+
return Classifier.canonicalizeValues(values);
116+
}
117+
110118
int responseIdx = mojoModel.getResponseIdx();
111119

112120
String[] responseValues = mojoModel.getDomainValues(responseIdx);

pmml-sklearn-h2o/src/test/java/org/jpmml/sklearn/h2o/testing/SkLearnH2OTest.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ public void evaluateGradientBoostingAuto() throws Exception {
106106
evaluate("H2OGradientBoosting", AUTO);
107107
}
108108

109+
@Test
110+
public void evaluateLinearRegressionAuto() throws Exception {
111+
evaluate("H2OLinearRegression", AUTO);
112+
}
113+
109114
@Test
110115
public void evaluateLogisticRegressionAudit() throws Exception {
111116
String[] targetFields = createTargetFields(AUDIT_ADJUSTED);
@@ -114,8 +119,8 @@ public void evaluateLogisticRegressionAudit() throws Exception {
114119
}
115120

116121
@Test
117-
public void evaluateLinearRegressionAuto() throws Exception {
118-
evaluate("H2OLinearRegression", AUTO);
122+
public void evaluateOrdinalRegressionAuto() throws Exception {
123+
evaluate("H2OOrdinalRegression", AUTO);
119124
}
120125

121126
@Test

0 commit comments

Comments
 (0)