Skip to content

Commit

Permalink
Improved support for base2-encoded features
Browse files (browse the repository at this point in the history)
  • Loading branch information
vruusmann committed Jun 23, 2020
1 parent d3b9a2e commit 346915c
Show file tree
Hide file tree
Showing 5 changed files with 1,939 additions and 1,905 deletions.
15 changes: 15 additions & 0 deletions src/main/java/category_encoders/BaseNEncoder.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
import com.google.common.base.Strings;
import com.google.common.collect.LinkedHashMultimap;
import com.google.common.collect.SetMultimap;
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.OpType;
import org.jpmml.converter.BinaryFeature;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.FeatureUtil;
import org.jpmml.python.ClassDictUtil;
Expand Down Expand Up @@ -124,6 +128,17 @@ public FieldName getDerivedName(){
}
};

if(base == 2){
ContinuousFeature continuousFeature = baseFeature.toContinuousFeature();

DerivedField derivedField = (DerivedField)encoder.getField(continuousFeature.getName());

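// XXX: the derived field's operational type is forced to categorical before it is wrapped in a BinaryFeature for value 1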
derivedField.setOpType(OpType.CATEGORICAL);

baseFeature = new BinaryFeature(encoder, derivedField, 1);
}

baseFeatures.add(baseFeature);
}

Expand Down
23 changes: 19 additions & 4 deletions src/main/java/category_encoders/BaseNFeature.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,15 @@
import java.util.Set;
import java.util.function.Supplier;

import com.google.common.collect.Iterables;
import com.google.common.collect.SetMultimap;
import org.dmg.pmml.Apply;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.FieldRef;
import org.dmg.pmml.NormDiscrete;
import org.dmg.pmml.PMMLFunctions;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
Expand Down Expand Up @@ -65,25 +67,38 @@ public BaseNFeature(PMMLEncoder encoder, FieldName name, DataType dataType, int
public ContinuousFeature toContinuousFeature(){
FieldName name = getName();
DataType dataType = getDataType();
int base = getBase();
SetMultimap<Integer, ?> values = getValues();

Supplier<Expression> expressionSupplier = () -> {
Map<Integer, ? extends Collection<?>> valueMap = values.asMap();

if(base == 2){
Collection<?> categories = valueMap.get(1);

if(categories != null && categories.size() == 1){
Object category = Iterables.getOnlyElement(categories);

return new NormDiscrete(name, category);
}
}

Apply apply = null;

Apply prevIfApply = null;

Set<? extends Map.Entry<Integer, ? extends Collection<?>>> entries = (values.asMap()).entrySet();
Set<? extends Map.Entry<Integer, ? extends Collection<?>>> entries = valueMap.entrySet();
for(Map.Entry<Integer, ? extends Collection<?>> entry : entries){
Integer baseValue = entry.getKey();
Collection<?> categories = entry.getValue();

Apply isInApply = PMMLUtil.createApply(PMMLFunctions.ISIN, new FieldRef(name));
Apply valueApply = PMMLUtil.createApply((categories.size() == 1 ? PMMLFunctions.EQUAL : PMMLFunctions.ISIN), new FieldRef(name));

for(Object category : categories){
isInApply.addExpressions(PMMLUtil.createConstant(category, dataType));
valueApply.addExpressions(PMMLUtil.createConstant(category, dataType));
}

Apply ifApply = PMMLUtil.createApply(PMMLFunctions.IF, isInApply)
Apply ifApply = PMMLUtil.createApply(PMMLFunctions.IF, valueApply)
.addExpressions(PMMLUtil.createConstant(baseValue));

if(apply == null){
Expand Down
Loading

0 comments on commit 346915c

Please sign in to comment.