diff --git a/src/main/java/sklearn/SkLearnOutlierTransformation.java b/src/main/java/sklearn/SkLearnOutlierTransformation.java
new file mode 100644
index 000000000..c41d99e7d
--- /dev/null
+++ b/src/main/java/sklearn/SkLearnOutlierTransformation.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) 2021 Villu Ruusmann
+ *
+ * This file is part of JPMML-SkLearn
+ *
+ * JPMML-SkLearn is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * JPMML-SkLearn is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Affero General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with JPMML-SkLearn. If not, see .
+ */
+package sklearn;
+
+import java.util.Arrays;
+import java.util.List;
+
+import org.dmg.pmml.DataType;
+import org.dmg.pmml.Expression;
+import org.dmg.pmml.FieldName;
+import org.dmg.pmml.FieldRef;
+import org.dmg.pmml.OpType;
+import org.dmg.pmml.Output;
+import org.dmg.pmml.OutputField;
+import org.dmg.pmml.PMMLFunctions;
+import org.jpmml.converter.FieldNameUtil;
+import org.jpmml.converter.PMMLUtil;
+import org.jpmml.converter.Transformation;
+
+public class SkLearnOutlierTransformation implements Transformation {
+
+ @Override
+ public FieldName getName(FieldName name){
+ return FieldNameUtil.create(Estimator.FIELD_PREDICT, name);
+ }
+
+ @Override
+ public OpType getOpType(OpType opType){
+ return OpType.CATEGORICAL;
+ }
+
+ @Override
+ public DataType getDataType(DataType dataType){
+ return DataType.INTEGER;
+ }
+
+ @Override
+ public boolean isFinalResult(){
+ return true;
+ }
+
+ @Override
+ public Expression createExpression(FieldRef fieldRef){
+ return PMMLUtil.createApply(PMMLFunctions.IF, fieldRef, PMMLUtil.createConstant(VALUE_OUTLIER), PMMLUtil.createConstant(VALUE_INLIER));
+ }
+
+ static
+ public void decorate(Output output){
+
+ if(output != null && output.hasOutputFields()){
+ List outputFields = output.getOutputFields();
+
+ OutputField finalOutputField = outputFields.get(outputFields.size() - 1);
+
+ PMMLUtil.addValues(finalOutputField, Arrays.asList(VALUE_OUTLIER, VALUE_INLIER));
+ }
+ }
+
+ public static final Integer VALUE_INLIER = +1;
+ public static final Integer VALUE_OUTLIER = -1;
+}
\ No newline at end of file
diff --git a/src/main/java/sklearn/ensemble/iforest/IsolationForest.java b/src/main/java/sklearn/ensemble/iforest/IsolationForest.java
index 43456b0bb..f7f8394a6 100644
--- a/src/main/java/sklearn/ensemble/iforest/IsolationForest.java
+++ b/src/main/java/sklearn/ensemble/iforest/IsolationForest.java
@@ -29,6 +29,7 @@
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
+import org.dmg.pmml.Output;
import org.dmg.pmml.PMMLFunctions;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Visitor;
@@ -52,6 +53,7 @@
import org.jpmml.python.HasArray;
import sklearn.Estimator;
import sklearn.Regressor;
+import sklearn.SkLearnOutlierTransformation;
import sklearn.SkLearnUtil;
import sklearn.ensemble.EnsembleRegressor;
import sklearn.ensemble.EnsembleUtil;
@@ -223,9 +225,15 @@ public Expression createExpression(FieldRef fieldRef){
}
};
+ Transformation sklearnOutlier = new SkLearnOutlierTransformation();
+
+ Output output = ModelUtil.createPredictedOutput(FieldName.create("rawAnomalyScore"), OpType.CONTINUOUS, DataType.DOUBLE, normalizedAnomalyScore, decisionFunction, outlier, sklearnOutlier);
+
+ SkLearnOutlierTransformation.decorate(output);
+
MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()))
.setSegmentation(MiningModelUtil.createSegmentation(MultipleModelMethod.AVERAGE, treeModels))
- .setOutput(ModelUtil.createPredictedOutput(FieldName.create("rawAnomalyScore"), OpType.CONTINUOUS, DataType.DOUBLE, normalizedAnomalyScore, decisionFunction, outlier));
+ .setOutput(output);
return TreeUtil.transform(this, miningModel);
}
diff --git a/src/main/java/sklearn/svm/OneClassSVM.java b/src/main/java/sklearn/svm/OneClassSVM.java
index 744ec30da..70fdb0bef 100644
--- a/src/main/java/sklearn/svm/OneClassSVM.java
+++ b/src/main/java/sklearn/svm/OneClassSVM.java
@@ -35,6 +35,7 @@
import org.jpmml.converter.Schema;
import org.jpmml.converter.Transformation;
import sklearn.Estimator;
+import sklearn.SkLearnOutlierTransformation;
public class OneClassSVM extends LibSVMRegressor {
@@ -62,15 +63,11 @@ public Expression createExpression(FieldRef fieldRef){
}
};
- SupportVectorMachineModel supportVectorMachineModel = super.encodeModel(schema)
- .setOutput(ModelUtil.createPredictedOutput(FieldName.create(Estimator.FIELD_DECISION_FUNCTION), OpType.CONTINUOUS, DataType.DOUBLE, outlier));
+ Transformation sklearnOutlier = new SkLearnOutlierTransformation();
- Output output = supportVectorMachineModel.getOutput();
+ Output output = ModelUtil.createPredictedOutput(FieldName.create(Estimator.FIELD_DECISION_FUNCTION), OpType.CONTINUOUS, DataType.DOUBLE, outlier, sklearnOutlier);
List outputFields = output.getOutputFields();
- if(outputFields.size() != 2){
- throw new IllegalArgumentException();
- }
OutputField decisionFunctionOutputField = outputFields.get(0);
@@ -78,6 +75,11 @@ public Expression createExpression(FieldRef fieldRef){
decisionFunctionOutputField.setFinalResult(true);
}
+ SkLearnOutlierTransformation.decorate(output);
+
+ SupportVectorMachineModel supportVectorMachineModel = super.encodeModel(schema)
+ .setOutput(output);
+
return supportVectorMachineModel;
}
}
\ No newline at end of file
diff --git a/src/main/java/sklego/meta/EstimatorTransformer.java b/src/main/java/sklego/meta/EstimatorTransformer.java
index 6d6950dc4..3ec97fe55 100644
--- a/src/main/java/sklego/meta/EstimatorTransformer.java
+++ b/src/main/java/sklego/meta/EstimatorTransformer.java
@@ -30,7 +30,6 @@
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
-import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
@@ -165,14 +164,13 @@ public List encodeFeatures(List features, SkLearnEncoder encod
switch(opType){
case CATEGORICAL:
{
- DataType dataType = finalOutputField.getDataType();
+ OutputField finalPmmlOutputField = finalOutputField.getOutputField();
- switch(dataType){
- case BOOLEAN:
- return Collections.singletonList(new BooleanFeature(encoder, finalOutputField));
- default:
- throw new IllegalArgumentException();
+ if(!finalPmmlOutputField.hasValues()){
+ throw new IllegalArgumentException();
}
+
+ return Collections.singletonList(new CategoricalFeature(encoder, finalOutputField, finalPmmlOutputField.getValues()));
}
case CONTINUOUS:
{
diff --git a/src/test/java/org/jpmml/sklearn/OutlierDetectorTest.java b/src/test/java/org/jpmml/sklearn/OutlierDetectorTest.java
index 90b33a835..00371a851 100644
--- a/src/test/java/org/jpmml/sklearn/OutlierDetectorTest.java
+++ b/src/test/java/org/jpmml/sklearn/OutlierDetectorTest.java
@@ -18,18 +18,23 @@
*/
package org.jpmml.sklearn;
+import org.dmg.pmml.FieldName;
+import org.jpmml.converter.FieldNameUtil;
import org.jpmml.evaluator.testing.PMMLEquivalence;
import org.junit.Test;
+import sklearn.Estimator;
public class OutlierDetectorTest extends SkLearnTest implements Algorithms, Datasets {
@Test
public void evaluateIsolationForestHousing() throws Exception {
- evaluate(ISOLATION_FOREST, HOUSING, new PMMLEquivalence(5e-12, 5e-12));
+ evaluate(ISOLATION_FOREST, HOUSING, excludeFields(OutlierDetectorTest.predictedValue), new PMMLEquivalence(5e-12, 5e-12));
}
@Test
public void evaluateOneClassSVMHousing() throws Exception {
- evaluate(ONE_CLASS_SVM, HOUSING);
+ evaluate(ONE_CLASS_SVM, HOUSING, excludeFields(OutlierDetectorTest.predictedValue));
}
+
+ private static final FieldName predictedValue = FieldNameUtil.create(Estimator.FIELD_PREDICT, "outlier");
}
\ No newline at end of file