Skip to content

Commit

Permalink
Refactored outlier detection model schemas
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jul 4, 2021
1 parent df83171 commit e4dce3c
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 16 deletions.
77 changes: 77 additions & 0 deletions src/main/java/sklearn/SkLearnOutlierTransformation.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright (c) 2021 Villu Ruusmann
*
* This file is part of JPMML-SkLearn
*
* JPMML-SkLearn is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-SkLearn is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-SkLearn. If not, see <http://www.gnu.org/licenses/>.
*/
package sklearn;

import java.util.Arrays;
import java.util.List;

import org.dmg.pmml.DataType;
import org.dmg.pmml.Expression;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.PMMLFunctions;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.converter.PMMLUtil;
import org.jpmml.converter.Transformation;

/**
 * Output transformation that converts an outlier flag field into an integer
 * label: {@link #VALUE_OUTLIER} ({@code -1}) or {@link #VALUE_INLIER} ({@code +1}).
 *
 * <p>
 * The derived field is categorical, integer-typed, and marked as a final result.
 * </p>
 */
public class SkLearnOutlierTransformation implements Transformation {

	/** Label emitted when the referenced flag does not indicate an outlier. */
	public static final Integer VALUE_INLIER = +1;

	/** Label emitted when the referenced flag indicates an outlier. */
	public static final Integer VALUE_OUTLIER = -1;

	/**
	 * Derives the output field name from {@link Estimator#FIELD_PREDICT} and the name of the input field.
	 */
	@Override
	public FieldName getName(FieldName name){
		return FieldNameUtil.create(Estimator.FIELD_PREDICT, name);
	}

	/**
	 * The result is always categorical, regardless of the operational type of the input field.
	 */
	@Override
	public OpType getOpType(OpType opType){
		return OpType.CATEGORICAL;
	}

	/**
	 * The result is always an integer, regardless of the data type of the input field.
	 */
	@Override
	public DataType getDataType(DataType dataType){
		return DataType.INTEGER;
	}

	/**
	 * The derived field is reported as a final result.
	 */
	@Override
	public boolean isFinalResult(){
		return true;
	}

	/**
	 * Builds an {@code if} function application that uses the referenced field as the condition,
	 * {@link #VALUE_OUTLIER} as the first branch and {@link #VALUE_INLIER} as the second branch.
	 *
	 * @param fieldRef Reference to the field that acts as the condition.
	 */
	@Override
	public Expression createExpression(FieldRef fieldRef){
		Expression outlierLabel = PMMLUtil.createConstant(VALUE_OUTLIER);
		Expression inlierLabel = PMMLUtil.createConstant(VALUE_INLIER);

		return PMMLUtil.createApply(PMMLFunctions.IF, fieldRef, outlierLabel, inlierLabel);
	}

	/**
	 * Registers the two possible labels ({@link #VALUE_OUTLIER}, {@link #VALUE_INLIER}, in that order)
	 * as the values of the last output field.
	 *
	 * <p>
	 * Does nothing if the output is {@code null} or has no output fields.
	 * </p>
	 *
	 * @param output The output element whose last output field is to be decorated. May be {@code null}.
	 */
	public static void decorate(Output output){

		// Nothing to decorate
		if(output == null || !output.hasOutputFields()){
			return;
		}

		List<OutputField> fields = output.getOutputFields();

		OutputField lastField = fields.get(fields.size() - 1);

		PMMLUtil.addValues(lastField, Arrays.asList(VALUE_OUTLIER, VALUE_INLIER));
	}
}
10 changes: 9 additions & 1 deletion src/main/java/sklearn/ensemble/iforest/IsolationForest.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.PMMLFunctions;
import org.dmg.pmml.PMMLObject;
import org.dmg.pmml.Visitor;
Expand All @@ -52,6 +53,7 @@
import org.jpmml.python.HasArray;
import sklearn.Estimator;
import sklearn.Regressor;
import sklearn.SkLearnOutlierTransformation;
import sklearn.SkLearnUtil;
import sklearn.ensemble.EnsembleRegressor;
import sklearn.ensemble.EnsembleUtil;
Expand Down Expand Up @@ -223,9 +225,15 @@ public Expression createExpression(FieldRef fieldRef){
}
};

Transformation sklearnOutlier = new SkLearnOutlierTransformation();

Output output = ModelUtil.createPredictedOutput(FieldName.create("rawAnomalyScore"), OpType.CONTINUOUS, DataType.DOUBLE, normalizedAnomalyScore, decisionFunction, outlier, sklearnOutlier);

SkLearnOutlierTransformation.decorate(output);

MiningModel miningModel = new MiningModel(MiningFunction.REGRESSION, ModelUtil.createMiningSchema(schema.getLabel()))
.setSegmentation(MiningModelUtil.createSegmentation(MultipleModelMethod.AVERAGE, treeModels))
.setOutput(ModelUtil.createPredictedOutput(FieldName.create("rawAnomalyScore"), OpType.CONTINUOUS, DataType.DOUBLE, normalizedAnomalyScore, decisionFunction, outlier));
.setOutput(output);

return TreeUtil.transform(this, miningModel);
}
Expand Down
14 changes: 8 additions & 6 deletions src/main/java/sklearn/svm/OneClassSVM.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.jpmml.converter.Schema;
import org.jpmml.converter.Transformation;
import sklearn.Estimator;
import sklearn.SkLearnOutlierTransformation;

public class OneClassSVM extends LibSVMRegressor {

Expand Down Expand Up @@ -62,22 +63,23 @@ public Expression createExpression(FieldRef fieldRef){
}
};

SupportVectorMachineModel supportVectorMachineModel = super.encodeModel(schema)
.setOutput(ModelUtil.createPredictedOutput(FieldName.create(Estimator.FIELD_DECISION_FUNCTION), OpType.CONTINUOUS, DataType.DOUBLE, outlier));
Transformation sklearnOutlier = new SkLearnOutlierTransformation();

Output output = supportVectorMachineModel.getOutput();
Output output = ModelUtil.createPredictedOutput(FieldName.create(Estimator.FIELD_DECISION_FUNCTION), OpType.CONTINUOUS, DataType.DOUBLE, outlier, sklearnOutlier);

List<OutputField> outputFields = output.getOutputFields();
if(outputFields.size() != 2){
throw new IllegalArgumentException();
}

OutputField decisionFunctionOutputField = outputFields.get(0);

if(!decisionFunctionOutputField.isFinalResult()){
decisionFunctionOutputField.setFinalResult(true);
}

SkLearnOutlierTransformation.decorate(output);

SupportVectorMachineModel supportVectorMachineModel = super.encodeModel(schema)
.setOutput(output);

return supportVectorMachineModel;
}
}
12 changes: 5 additions & 7 deletions src/main/java/sklego/meta/EstimatorTransformer.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
Expand Down Expand Up @@ -165,14 +164,13 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod
switch(opType){
case CATEGORICAL:
{
DataType dataType = finalOutputField.getDataType();
OutputField finalPmmlOutputField = finalOutputField.getOutputField();

switch(dataType){
case BOOLEAN:
return Collections.singletonList(new BooleanFeature(encoder, finalOutputField));
default:
throw new IllegalArgumentException();
if(!finalPmmlOutputField.hasValues()){
throw new IllegalArgumentException();
}

return Collections.singletonList(new CategoricalFeature(encoder, finalOutputField, finalPmmlOutputField.getValues()));
}
case CONTINUOUS:
{
Expand Down
9 changes: 7 additions & 2 deletions src/test/java/org/jpmml/sklearn/OutlierDetectorTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,23 @@
*/
package org.jpmml.sklearn;

import org.dmg.pmml.FieldName;
import org.jpmml.converter.FieldNameUtil;
import org.jpmml.evaluator.testing.PMMLEquivalence;
import org.junit.Test;
import sklearn.Estimator;

public class OutlierDetectorTest extends SkLearnTest implements Algorithms, Datasets {

@Test
public void evaluateIsolationForestHousing() throws Exception {
evaluate(ISOLATION_FOREST, HOUSING, new PMMLEquivalence(5e-12, 5e-12));
evaluate(ISOLATION_FOREST, HOUSING, excludeFields(OutlierDetectorTest.predictedValue), new PMMLEquivalence(5e-12, 5e-12));
}

@Test
public void evaluateOneClassSVMHousing() throws Exception {
evaluate(ONE_CLASS_SVM, HOUSING);
evaluate(ONE_CLASS_SVM, HOUSING, excludeFields(OutlierDetectorTest.predictedValue));
}

private static final FieldName predictedValue = FieldNameUtil.create(Estimator.FIELD_PREDICT, "outlier");
}

0 comments on commit e4dce3c

Please sign in to comment.