From e4dce3cd80b5653c09bf05a41ef8f2fcf9dd5f7a Mon Sep 17 00:00:00 2001 From: Villu Ruusmann Date: Sun, 4 Jul 2021 13:44:46 +0300 Subject: [PATCH] Refactored outlier detection model schemas --- .../sklearn/SkLearnOutlierTransformation.java | 77 +++++++++++++++++++ .../ensemble/iforest/IsolationForest.java | 10 ++- src/main/java/sklearn/svm/OneClassSVM.java | 14 ++-- .../sklego/meta/EstimatorTransformer.java | 12 ++- .../jpmml/sklearn/OutlierDetectorTest.java | 9 ++- 5 files changed, 106 insertions(+), 16 deletions(-) create mode 100644 src/main/java/sklearn/SkLearnOutlierTransformation.java diff --git a/src/main/java/sklearn/SkLearnOutlierTransformation.java b/src/main/java/sklearn/SkLearnOutlierTransformation.java new file mode 100644 index 000000000..c41d99e7d --- /dev/null +++ b/src/main/java/sklearn/SkLearnOutlierTransformation.java @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2021 Villu Ruusmann + * + * This file is part of JPMML-SkLearn + * + * JPMML-SkLearn is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * JPMML-SkLearn is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with JPMML-SkLearn. If not, see . + */ +package sklearn; + +import java.util.Arrays; +import java.util.List; + +import org.dmg.pmml.DataType; +import org.dmg.pmml.Expression; +import org.dmg.pmml.FieldName; +import org.dmg.pmml.FieldRef; +import org.dmg.pmml.OpType; +import org.dmg.pmml.Output; +import org.dmg.pmml.OutputField; +import org.dmg.pmml.PMMLFunctions; +import org.jpmml.converter.FieldNameUtil; +import org.jpmml.converter.PMMLUtil; +import org.jpmml.converter.Transformation; + +public class SkLearnOutlierTransformation implements Transformation { + + @Override + public FieldName getName(FieldName name){ + return FieldNameUtil.create(Estimator.FIELD_PREDICT, name); + } + + @Override + public OpType getOpType(OpType opType){ + return OpType.CATEGORICAL; + } + + @Override + public DataType getDataType(DataType dataType){ + return DataType.INTEGER; + } + + @Override + public boolean isFinalResult(){ + return true; + } + + @Override + public Expression createExpression(FieldRef fieldRef){ + return PMMLUtil.createApply(PMMLFunctions.IF, fieldRef, PMMLUtil.createConstant(VALUE_OUTLIER), PMMLUtil.createConstant(VALUE_INLIER)); + } + + static + public void decorate(Output output){ + + if(output != null && output.hasOutputFields()){ + List outputFields = output.getOutputFields(); + + OutputField finalOutputField = outputFields.get(outputFields.size() - 1); + + PMMLUtil.addValues(finalOutputField, Arrays.asList(VALUE_OUTLIER, VALUE_INLIER)); + } + } + + public static final Integer VALUE_INLIER = +1; + public static final Integer VALUE_OUTLIER = -1; +} \ No newline at end of file diff --git a/src/main/java/sklearn/ensemble/iforest/IsolationForest.java b/src/main/java/sklearn/ensemble/iforest/IsolationForest.java index 43456b0bb..f7f8394a6 100644 --- a/src/main/java/sklearn/ensemble/iforest/IsolationForest.java +++ b/src/main/java/sklearn/ensemble/iforest/IsolationForest.java @@ -29,6 +29,7 @@ import org.dmg.pmml.FieldRef; import org.dmg.pmml.MiningFunction; import org.dmg.pmml.OpType; +import org.dmg.pmml.Output; import org.dmg.pmml.PMMLFunctions; import org.dmg.pmml.PMMLObject; import org.dmg.pmml.Visitor; @@ -52,6 +53,7 @@ import org.jpmml.python.HasArray; import sklearn.Estimator; import sklearn.Regressor; +import sklearn.SkLearnOutlierTransformation; import sklearn.SkLearnUtil; import sklearn.ensemble.EnsembleRegressor; import sklearn.ensemble.EnsembleUtil; @@ -223,9 +225,15 @@ public Expression createExpression(FieldRef fieldRef){ } }; + Transformation sklearnOutlier = new SkLearnOutlierTransformation(); + + Output output = ModelUtil.createPredictedOutput(FieldName.create("rawAnomalyScore"), OpType.CONTINUOUS, DataType.DOUBLE, normalizedAnomalyScore, decisionFunction, outlier, sklearnOutlier); + + SkLearnOutlierTransformation.decorate(output); + MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel())) .setSegmentation(MiningModelUtil.createSegmentation(MultipleModelMethod.AVERAGE, treeModels)) - .setOutput(ModelUtil.createPredictedOutput(FieldName.create("rawAnomalyScore"), OpType.CONTINUOUS, DataType.DOUBLE, normalizedAnomalyScore, decisionFunction, outlier)); + .setOutput(output); return TreeUtil.transform(this, miningModel); } diff --git a/src/main/java/sklearn/svm/OneClassSVM.java b/src/main/java/sklearn/svm/OneClassSVM.java index 744ec30da..70fdb0bef 100644 --- a/src/main/java/sklearn/svm/OneClassSVM.java +++ b/src/main/java/sklearn/svm/OneClassSVM.java @@ -35,6 +35,7 @@ import org.jpmml.converter.Schema; import org.jpmml.converter.Transformation; import sklearn.Estimator; +import sklearn.SkLearnOutlierTransformation; public class OneClassSVM extends LibSVMRegressor { @@ -62,15 +63,11 @@ public Expression createExpression(FieldRef fieldRef){ } }; - SupportVectorMachineModel supportVectorMachineModel = super.encodeModel(schema) - .setOutput(ModelUtil.createPredictedOutput(FieldName.create(Estimator.FIELD_DECISION_FUNCTION), OpType.CONTINUOUS, DataType.DOUBLE, outlier)); + Transformation sklearnOutlier = new SkLearnOutlierTransformation(); - Output output = supportVectorMachineModel.getOutput(); + Output output = ModelUtil.createPredictedOutput(FieldName.create(Estimator.FIELD_DECISION_FUNCTION), OpType.CONTINUOUS, DataType.DOUBLE, outlier, sklearnOutlier); List outputFields = output.getOutputFields(); - if(outputFields.size() != 2){ - throw new IllegalArgumentException(); - } OutputField decisionFunctionOutputField = outputFields.get(0); @@ -78,6 +75,11 @@ public Expression createExpression(FieldRef fieldRef){ decisionFunctionOutputField.setFinalResult(true); } + SkLearnOutlierTransformation.decorate(output); + + SupportVectorMachineModel supportVectorMachineModel = super.encodeModel(schema) + .setOutput(output); + return supportVectorMachineModel; } } \ No newline at end of file diff --git a/src/main/java/sklego/meta/EstimatorTransformer.java b/src/main/java/sklego/meta/EstimatorTransformer.java index 6d6950dc4..3ec97fe55 100644 --- a/src/main/java/sklego/meta/EstimatorTransformer.java +++ b/src/main/java/sklego/meta/EstimatorTransformer.java @@ -30,7 +30,6 @@ import org.dmg.pmml.Output; import org.dmg.pmml.OutputField; import org.dmg.pmml.ResultFeature; -import org.jpmml.converter.BooleanFeature; import org.jpmml.converter.CategoricalFeature; import org.jpmml.converter.CategoricalLabel; import org.jpmml.converter.ContinuousFeature; @@ -165,14 +164,13 @@ public List encodeFeatures(List features, SkLearnEncoder encod switch(opType){ case CATEGORICAL: { - DataType dataType = finalOutputField.getDataType(); + OutputField finalPmmlOutputField = finalOutputField.getOutputField(); - switch(dataType){ - case BOOLEAN: - return Collections.singletonList(new BooleanFeature(encoder, finalOutputField)); - default: - throw new IllegalArgumentException(); + if(!finalPmmlOutputField.hasValues()){ + throw new IllegalArgumentException(); } + + return Collections.singletonList(new CategoricalFeature(encoder, finalOutputField, finalPmmlOutputField.getValues())); } case CONTINUOUS: { diff --git a/src/test/java/org/jpmml/sklearn/OutlierDetectorTest.java b/src/test/java/org/jpmml/sklearn/OutlierDetectorTest.java index 90b33a835..00371a851 100644 --- a/src/test/java/org/jpmml/sklearn/OutlierDetectorTest.java +++ b/src/test/java/org/jpmml/sklearn/OutlierDetectorTest.java @@ -18,18 +18,23 @@ */ package org.jpmml.sklearn; +import org.dmg.pmml.FieldName; +import org.jpmml.converter.FieldNameUtil; import org.jpmml.evaluator.testing.PMMLEquivalence; import org.junit.Test; +import sklearn.Estimator; public class OutlierDetectorTest extends SkLearnTest implements Algorithms, Datasets { @Test public void evaluateIsolationForestHousing() throws Exception { - evaluate(ISOLATION_FOREST, HOUSING, new PMMLEquivalence(5e-12, 5e-12)); + evaluate(ISOLATION_FOREST, HOUSING, excludeFields(OutlierDetectorTest.predictedValue), new PMMLEquivalence(5e-12, 5e-12)); } @Test public void evaluateOneClassSVMHousing() throws Exception { - evaluate(ONE_CLASS_SVM, HOUSING); + evaluate(ONE_CLASS_SVM, HOUSING, excludeFields(OutlierDetectorTest.predictedValue)); } + + private static final FieldName predictedValue = FieldNameUtil.create(Estimator.FIELD_PREDICT, "outlier"); } \ No newline at end of file