Skip to content

Commit

Permalink
Added support for the 'H2OEstimator.pmml_classes_' attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Mar 24, 2024
1 parent 620c521 commit c73787f
Show file tree
Hide file tree
Showing 6 changed files with 439 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@
import org.jpmml.python.ClassDictUtil;
import org.jpmml.sklearn.Encodable;
import org.jpmml.sklearn.SkLearnEncoder;
import sklearn.Classifier;
import sklearn.Estimator;
import sklearn.HasClasses;
import sklearn2pmml.SkLearn2PMMLFields;

public class H2OEstimator extends Estimator implements HasClasses, Encodable {

Expand Down Expand Up @@ -107,6 +109,12 @@ public int getNumberOfOutputs(){
public List<?> getClasses(){
MojoModel mojoModel = getMojoModel();

if(containsKey(SkLearn2PMMLFields.PMML_CLASSES)){
List<?> values = getListLike(SkLearn2PMMLFields.PMML_CLASSES);

return Classifier.canonicalizeValues(values);
}

int responseIdx = mojoModel.getResponseIdx();

String[] responseValues = mojoModel.getDomainValues(responseIdx);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ public void evaluateGradientBoostingAuto() throws Exception {
evaluate("H2OGradientBoosting", AUTO);
}

@Test
public void evaluateLinearRegressionAuto() throws Exception {
evaluate("H2OLinearRegression", AUTO);
}

@Test
public void evaluateLogisticRegressionAudit() throws Exception {
String[] targetFields = createTargetFields(AUDIT_ADJUSTED);
Expand All @@ -114,8 +119,8 @@ public void evaluateLogisticRegressionAudit() throws Exception {
}

@Test
public void evaluateLinearRegressionAuto() throws Exception {
evaluate("H2OLinearRegression", AUTO);
public void evaluateOrdinalRegressionAuto() throws Exception {
evaluate("H2OOrdinalRegression", AUTO);
}

@Test
Expand Down
Loading

0 comments on commit c73787f

Please sign in to comment.