Skip to content

Commit

Permalink
Improved support for predicate attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Feb 26, 2023
1 parent 9dc988b commit 375760a
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import java.util.AbstractList;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -163,13 +162,13 @@ public ScalarLabel get(int index){

String name = TupleUtil.extractElement(step, 0, String.class);
Estimator estimator = TupleUtil.extractElement(step, 1, Estimator.class);
String predicate = TupleUtil.extractElement(step, 2, String.class);
Object expr = TupleUtil.extractElement(step, 2, Object.class);

estimators.add(estimator);

Schema segmentSchema = schema.toRelabeledSchema(multiLabel.getLabel(i));

Predicate pmmlPredicate = EvaluatableUtil.translatePredicate(predicate, Collections.emptyList(), scope);
Predicate predicate = EvaluatableUtil.translatePredicate(expr, scope);

Model model = estimator.encode(segmentSchema);

Expand All @@ -181,7 +180,7 @@ public ScalarLabel get(int index){
schema = link.augmentSchema(model, segmentSchema);
}

Segment segment = new Segment(pmmlPredicate, model)
Segment segment = new Segment(predicate, model)
.setId(name);

segmentation.addSegments(segment);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
*/
package sklearn2pmml.ensemble;

import java.util.Collections;
import java.util.List;

import org.dmg.pmml.MiningFunction;
Expand Down Expand Up @@ -72,17 +71,17 @@ private MiningModel encodeModel(MiningFunction miningFunction, List<Object[]> st

String name = TupleUtil.extractElement(step, 0, String.class);
Estimator estimator = TupleUtil.extractElement(step, 1, Estimator.class);
String predicate = TupleUtil.extractElement(step, 2, String.class);
Object expr = TupleUtil.extractElement(step, 2, Object.class);

if(estimator.getMiningFunction() != miningFunction){
throw new IllegalArgumentException();
}

Predicate pmmlPredicate = EvaluatableUtil.translatePredicate(predicate, Collections.emptyList(), scope);
Predicate predicate = EvaluatableUtil.translatePredicate(expr, scope);

Model model = estimator.encode(schema);

Segment segment = new Segment(pmmlPredicate, model)
Segment segment = new Segment(predicate, model)
.setId(name);

segmentation.addSegments(segment);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,12 @@ public RuleSetModel encodeModel(Schema schema){
Scope scope = new DataFrameScope("X", features);

for(Object[] rule : rules){
String predicate = TupleUtil.extractElement(rule, 0, String.class);
Object expr = TupleUtil.extractElement(rule, 0, Object.class);
String score = TupleUtil.extractElement(rule, 1, String.class);

Predicate pmmlPredicate = EvaluatableUtil.translatePredicate(predicate, Collections.emptyList(), scope);
Predicate predicate = EvaluatableUtil.translatePredicate(expr, scope);

SimpleRule simpleRule = new SimpleRule(score, pmmlPredicate);
SimpleRule simpleRule = new SimpleRule(score, predicate);

ruleSet.addRules(simpleRule);
}
Expand Down
12 changes: 12 additions & 0 deletions pmml-sklearn/src/main/java/sklearn2pmml/util/EvaluatableUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,18 @@ public org.dmg.pmml.Expression translateExpression(String expr, List<String> fun
return expressionTranslator.translateExpression(expr);
}

/**
 * Translates a predicate given either as a {@code Predicate} object or as a string expression.
 *
 * <p>
 * A {@code Predicate} instance is asked to translate itself against the scope.
 * Any other value is cast to {@link String} and handed to the string-based
 * {@code translatePredicate(String, List, Scope)} overload with an empty list of function definitions.
 * A value that is neither a {@code Predicate} nor a {@code String} therefore fails with a {@link ClassCastException}.
 * </p>
 *
 * <p>
 * NOTE(review): the unqualified {@code Predicate} is presumably the sklearn2pmml-side wrapper type,
 * not {@code org.dmg.pmml.Predicate} - confirm against the imports of this file.
 * </p>
 *
 * @param expr The predicate, as a {@code Predicate} object or a string expression.
 * @param scope The scope against which the predicate is translated.
 *
 * @return The translated PMML predicate.
 */
static
public org.dmg.pmml.Predicate translatePredicate(Object expr, Scope scope){

	// Everything that is not a ready-made predicate object is treated as a string expression
	if(!(expr instanceof Predicate)){
		String string = (String)expr;

		return translatePredicate(string, Collections.emptyList(), scope);
	}

	Predicate wrapper = (Predicate)expr;

	return wrapper.translate(scope);
}

static
public org.dmg.pmml.Predicate translatePredicate(String expr, List<String> functionDefs, Scope scope){
PredicateTranslator predicateTranslator = new PredicateTranslator(scope);
Expand Down
3 changes: 2 additions & 1 deletion pmml-sklearn/src/test/resources/extensions/sklearn2pmml.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from sklearn2pmml.preprocessing import CutTransformer
from sklearn2pmml.ruleset import RuleSetClassifier
from sklearn2pmml.tree.chaid import CHAIDClassifier, CHAIDRegressor
from sklearn2pmml.util import Predicate

import numpy
import sys
Expand Down Expand Up @@ -122,7 +123,7 @@ def build_ruleset_iris(iris_df, name):
iris_X, iris_y = split_csv(iris_df)

classifier = RuleSetClassifier([
("X['Petal.Length'] >= 2.45 and X['Petal.Width'] < 1.75", "versicolor"),
(Predicate("X['Petal.Length'] >= 2.45 and X['Petal.Width'] < 1.75"), "versicolor"),
("X['Petal.Length'] >= 2.45", "virginica")
], default_score = "setosa")
pipeline = PMMLPipeline([
Expand Down
Binary file modified pmml-sklearn/src/test/resources/pkl/RuleSetIris.pkl
Binary file not shown.

0 comments on commit 375760a

Please sign in to comment.