Skip to content

Commit

Permalink
Improved support for the 'EstimatorTransformer' transformation type
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jul 4, 2021
1 parent 772bcb0 commit df83171
Show file tree
Hide file tree
Showing 7 changed files with 284 additions and 47 deletions.
140 changes: 97 additions & 43 deletions src/main/java/sklego/meta/EstimatorTransformer.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package sklego.meta;

import java.util.Collections;
import java.util.Iterator;
import java.util.List;

import org.dmg.pmml.DataType;
Expand All @@ -28,6 +29,8 @@
import org.dmg.pmml.OpType;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
Expand Down Expand Up @@ -55,78 +58,129 @@ public List<Feature> encodeFeatures(List<Feature> features, SkLearnEncoder encod
Estimator estimator = getEstimator();
String predictFunc = getPredictFunc();

if(!estimator.isSupervised()){
throw new IllegalArgumentException();
}

switch(predictFunc){
case "predict":
break;
default:
throw new IllegalArgumentException(predictFunc);
}

MiningFunction miningFunction = estimator.getMiningFunction();

Label label;

switch(miningFunction){
case CLASSIFICATION:
{
List<?> categories = ClassifierUtil.getClasses(estimator);

DataType dataType = TypeUtil.getDataType(categories, DataType.STRING);

label = new CategoricalLabel(null, dataType, categories);
}
break;
case REGRESSION:
{
label = new ContinuousLabel(null, DataType.DOUBLE);
}
break;
default:
throw new IllegalArgumentException();
Label label = null;

if(estimator.isSupervised()){
MiningFunction miningFunction = estimator.getMiningFunction();

switch(miningFunction){
case CLASSIFICATION:
{
List<?> categories = ClassifierUtil.getClasses(estimator);

DataType dataType = TypeUtil.getDataType(categories, DataType.STRING);

label = new CategoricalLabel(null, dataType, categories);
}
break;
case REGRESSION:
{
label = new ContinuousLabel(null, DataType.DOUBLE);
}
break;
default:
throw new IllegalArgumentException();
}
}

Schema schema = new Schema(encoder, label, features);

Model model = estimator.encode(schema);

DerivedOutputField finalOutputField = null;

Output output = model.getOutput();
if(output != null && output.hasOutputFields()){
List<OutputField> outputFields = output.getOutputFields();

outputFields.clear();
for(Iterator<OutputField> it = outputFields.iterator(); it.hasNext(); ){
OutputField outputField = it.next();

ResultFeature resultFeature = outputField.getResultFeature();
switch(resultFeature){
case PREDICTED_VALUE:
case TRANSFORMED_VALUE:
{
finalOutputField = encoder.createDerivedField(model, outputField, true);
}
break;
default:
break;
}

it.remove();
}
}

encoder.addTransformer(model);

FieldName name = createFieldName("estimator");
if(estimator.isSupervised()){
MiningFunction miningFunction = estimator.getMiningFunction();

switch(miningFunction){
case CLASSIFICATION:
{
CategoricalLabel categoricalLabel = (CategoricalLabel)label;
if(finalOutputField != null){
throw new IllegalArgumentException();
}

OutputField predictedOutputField = ModelUtil.createPredictedField(name, OpType.CATEGORICAL, categoricalLabel.getDataType());
FieldName name = createFieldName(Estimator.FIELD_PREDICT);

DerivedOutputField predictedField = encoder.createDerivedField(model, predictedOutputField, false);
switch(miningFunction){
case CLASSIFICATION:
{
CategoricalLabel categoricalLabel = (CategoricalLabel)label;

return Collections.singletonList(new CategoricalFeature(encoder, predictedField, categoricalLabel.getValues()));
}
case REGRESSION:
{
ContinuousLabel continuousLabel = (ContinuousLabel)label;
OutputField predictedOutputField = ModelUtil.createPredictedField(name, OpType.CATEGORICAL, categoricalLabel.getDataType());

OutputField predictedOutputField = ModelUtil.createPredictedField(name, OpType.CONTINUOUS, continuousLabel.getDataType());
DerivedOutputField predictedField = encoder.createDerivedField(model, predictedOutputField, false);

DerivedOutputField predictedField = encoder.createDerivedField(model, predictedOutputField, false);
return Collections.singletonList(new CategoricalFeature(encoder, predictedField, categoricalLabel.getValues()));
}
case REGRESSION:
{
ContinuousLabel continuousLabel = (ContinuousLabel)label;

return Collections.singletonList(new ContinuousFeature(encoder, predictedField));
}
default:
OutputField predictedOutputField = ModelUtil.createPredictedField(name, OpType.CONTINUOUS, continuousLabel.getDataType());

DerivedOutputField predictedField = encoder.createDerivedField(model, predictedOutputField, false);

return Collections.singletonList(new ContinuousFeature(encoder, predictedField));
}
default:
throw new IllegalArgumentException();
}
} else

{
if(finalOutputField == null){
throw new IllegalArgumentException();
}

OpType opType = finalOutputField.getOpType();
switch(opType){
case CATEGORICAL:
{
DataType dataType = finalOutputField.getDataType();

switch(dataType){
case BOOLEAN:
return Collections.singletonList(new BooleanFeature(encoder, finalOutputField));
default:
throw new IllegalArgumentException();
}
}
case CONTINUOUS:
{
return Collections.singletonList(new ContinuousFeature(encoder, finalOutputField));
}
default:
throw new IllegalArgumentException();
}
}
}

Expand Down
5 changes: 5 additions & 0 deletions src/test/java/org/jpmml/sklearn/SkLegoTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,9 @@ public void evaluateEstimatorTransformerAudit() throws Exception {
public void evaluateEstimatorTransformerAuto() throws Exception {
evaluate("EstimatorTransformer", AUTO);
}

@Test
public void evaluateEstimatorTransformerIris() throws Exception {
evaluate("EstimatorTransformer", IRIS);
}
}
151 changes: 151 additions & 0 deletions src/test/resources/csv/EstimatorTransformerIris.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
Species,probability(setosa),probability(versicolor),probability(virginica)
setosa,0.9816557857088485,0.01834419979666306,1.4494488439257807e-08
setosa,0.9714678810353109,0.028532088799475375,3.0165213672777654e-08
setosa,0.9853356464276457,0.014664341242283531,1.2330070627928572e-08
setosa,0.9761577341733771,0.023842226156299592,3.967032345834187e-08
setosa,0.9852874553149072,0.014712532683384655,1.2001708066648066e-08
setosa,0.9703314981335994,0.029668427909907327,7.395649336987421e-08
setosa,0.9868231210617859,0.013176858970905429,1.9967308815766944e-08
setosa,0.976239893123865,0.023760079156014795,2.7720120096442044e-08
setosa,0.979711986944132,0.020287982483750452,3.057211769991342e-08
setosa,0.9688904264837149,0.031109541803075392,3.1713209669470587e-08
setosa,0.9763217723924839,0.02367820823775132,1.9369764741261784e-08
setosa,0.9752961106083841,0.024703845457045227,4.393457077858676e-08
setosa,0.9743381757383316,0.025661802757077124,2.1504591358140706e-08
setosa,0.9919051635702665,0.008094832541629796,3.8881037891440475e-09
setosa,0.9865181622022166,0.013481834549658114,3.2481253777803826e-09
setosa,0.9849760187813261,0.015023966444962745,1.4773711023173906e-08
setosa,0.9880173543037245,0.011982636437728457,9.258547230436284e-09
setosa,0.9814138303007075,0.018586149939460502,1.9759832129584387e-08
setosa,0.9562841148546888,0.04371581617575972,6.896955163589114e-08
setosa,0.9840378646898743,0.015962114678716076,2.0631409638948603e-08
setosa,0.946386554482804,0.053613358676689364,8.684050672746253e-08
setosa,0.9816465646039694,0.018353402423358284,3.297267225004004e-08
setosa,0.9959759159104763,0.004024082773163526,1.3163602431588677e-09
setosa,0.9520745760751353,0.0479251869652793,2.3695958521685349e-07
setosa,0.951766340106597,0.04823345339881755,2.0649458544595049e-07
setosa,0.951191323929018,0.048808589164161205,8.690682082660468e-08
setosa,0.9694637141990881,0.030536199360978532,8.643993326462277e-08
setosa,0.9747444206657102,0.02525555425744081,2.5076849055340903e-08
setosa,0.9771484596659553,0.022851522844932203,1.7489112433553372e-08
setosa,0.9710847906348397,0.02891515083941491,5.8525745319197615e-08
setosa,0.9640751387781432,0.035924790790794234,7.043106263314677e-08
setosa,0.964650969077886,0.03534897321094549,5.7711168624619883e-08
setosa,0.9883249736672969,0.01167501924300681,7.089696413033195e-09
setosa,0.9889932972775943,0.011006697367499235,5.354906296990481e-09
setosa,0.9684855022020719,0.03151445457182609,4.322610193605222e-08
setosa,0.984508577055292,0.015491414932751007,8.011956896289525e-09
setosa,0.978722409508271,0.021277580795340707,9.696388178252566e-09
setosa,0.9867864273051586,0.013213564151125583,8.543715698336522e-09
setosa,0.985756108974816,0.014243875521581732,1.5503602113155224e-08
setosa,0.9739178090094025,0.026082162458747303,2.8531850076435938e-08
setosa,0.9865252954726859,0.01347469312742159,1.1399892503240025e-08
setosa,0.9618503706914081,0.03814956354454232,6.57640496532111e-08
setosa,0.9889524909060696,0.011047497825858765,1.1268071421744312e-08
setosa,0.9723423594284019,0.027657503362656925,1.3720894131767065e-07
setosa,0.9602049275106167,0.039794849211171074,2.2327821214454493e-07
setosa,0.9736620606869596,0.02633789935496568,3.995807460793329e-08
setosa,0.9802300256992462,0.01976994884844664,2.5452307143335822e-08
setosa,0.983242614087709,0.01675736577307926,2.0139211853404276e-08
setosa,0.9784345723822081,0.021565408803230546,1.8814561125619823e-08
setosa,0.9785047204539143,0.021495260216324197,1.9329761346194388e-08
versicolor,0.0021403945769165773,0.8738158219342768,0.12404378348880664
versicolor,0.005855444925963602,0.8596281943352834,0.134516360738753
versicolor,0.0010676358944998418,0.7249288602981188,0.2740035038073816
versicolor,0.015508298666405313,0.9395772692944606,0.044914432039134033
versicolor,0.0023932317305477703,0.8150842483231274,0.18252251994632485
versicolor,0.007023758465687699,0.8601131321827501,0.1328631093515622
versicolor,0.0037955119868477633,0.7164684696692801,0.2797360183438721
versicolor,0.1484312994701469,0.8485145547809985,0.0030541457488546236
versicolor,0.002794519935100642,0.896549711949919,0.10065576811498042
versicolor,0.041706697551368785,0.9116853160669561,0.04660798638167505
versicolor,0.04990404355502327,0.9435103174271131,0.006585639017863489
versicolor,0.015273471351751814,0.898708762279102,0.08601776636914613
versicolor,0.009144624004484177,0.976508240605561,0.014347135389954814
versicolor,0.003068090562929333,0.7792555944306628,0.21767631500640797
versicolor,0.07486312136484996,0.9146797145136091,0.010457164121540966
versicolor,0.005317291444456054,0.9262363298295911,0.06844637872595302
versicolor,0.00876711332316279,0.7747077853972829,0.21652510127955432
versicolor,0.016525369739786643,0.9651003883032325,0.01837424195698075
versicolor,0.0018240619157884387,0.8009057227197649,0.19727021536444656
versicolor,0.024089091928456558,0.9593044133384547,0.0166064947330888
virginica,0.0023229148930482674,0.4402682626074164,0.5574088224995354
versicolor,0.016929025303059,0.9566532330521299,0.026417741644811112
versicolor,0.0007201715632936571,0.5962646274047237,0.40301520103198263
versicolor,0.00305730811936091,0.8598854877102081,0.1370572041704309
versicolor,0.007119208433970069,0.942863911946568,0.05001687961946182
versicolor,0.005109715556765625,0.9199902264409299,0.07490005800230462
versicolor,0.0011273020756997403,0.8013892246818846,0.19748347324241572
virginica,0.000582776971329284,0.4810518433455018,0.518365379683169
versicolor,0.005518532268280226,0.8130360603055276,0.18144540742619225
versicolor,0.0619793042292597,0.9346376061859181,0.003383089584822318
versicolor,0.029299099688802212,0.9570811934722382,0.013619706838959612
versicolor,0.037338735129976354,0.9551239580606998,0.007537306809323797
versicolor,0.02531544474114868,0.9563327123271644,0.01835184293168686
virginica,0.00045256277126416554,0.3498749563772026,0.6496724808515332
versicolor,0.010289726981110596,0.7509745860262257,0.23873568699266373
versicolor,0.010051279241006447,0.7886417973646417,0.201306923394352
versicolor,0.0022792677877330747,0.8050342668085643,0.19268646540370257
versicolor,0.0027817399420236497,0.9130848012324886,0.0841334588254878
versicolor,0.027203379808996484,0.9283722011964929,0.04442441899451059
versicolor,0.020025588288226565,0.9379580866258838,0.042016325085889726
versicolor,0.008764274623179687,0.8978791816590665,0.09335654371775376
versicolor,0.004673213219793229,0.8282197900027163,0.16710699677749036
versicolor,0.01767512032088177,0.9569022503536794,0.025422629325438845
versicolor,0.12226522395919999,0.8746852610475213,0.0030495149932786153
versicolor,0.014548475854451102,0.9203832439939879,0.06506828015156096
versicolor,0.020089554710681065,0.9379905361096673,0.04191990917965157
versicolor,0.01719426995771089,0.9253258747679578,0.0574798552743313
versicolor,0.008548108285653184,0.9350340117253453,0.05641787998900161
versicolor,0.24462601391555996,0.7540827979016416,0.0012911881827984805
versicolor,0.01926811259694251,0.935919703394333,0.04481218400872456
virginica,9.124836009404962e-07,0.003921820468955399,0.9960772670474437
virginica,0.0002448021954449444,0.16269814168626087,0.8370570561182941
virginica,2.502977335859906e-06,0.025587447564589305,0.9744100494580749
virginica,3.1475957569273263e-05,0.08177326548895592,0.918195258553475
virginica,3.7701474704083797e-06,0.017464241324556744,0.9825319885279729
virginica,5.582928292198184e-08,0.004643638669967625,0.9953563055007495
versicolor,0.005812761044703366,0.5140140766654759,0.48017316228982077
virginica,6.265958172506613e-07,0.021364659784438376,0.9786347136197445
virginica,5.270465676137816e-06,0.05332549468417085,0.9466692348501529
virginica,6.602687101084285e-07,0.005747702941620662,0.9942516367896693
virginica,0.00030447966374539504,0.210478382089449,0.7892171382468056
virginica,7.3333489902989e-05,0.13735442976706572,0.8625722367430312
virginica,2.1447309004997317e-05,0.06527793308837854,0.9347006196026164
virginica,0.0002313103553266701,0.14534619700606807,0.8544224926386051
virginica,6.970297055292738e-05,0.043543374136845905,0.9563869228926012
virginica,5.206781959454568e-05,0.05405988310408923,0.9458880490763162
virginica,5.6033544132025296e-05,0.12298684115555539,0.8769571253003127
virginica,8.540085102953086e-08,0.0035653494946546334,0.9964345651044942
virginica,2.782344161403108e-09,0.000990550040380482,0.9990094471772754
virginica,0.00039238183536130056,0.4520370660173033,0.5475705521473354
virginica,5.650882662294913e-06,0.023859337187192377,0.9761350119301454
virginica,0.0006186990623690562,0.19059265317159116,0.8087886477660398
virginica,3.1653021192319614e-08,0.004659192108613048,0.9953407762383657
virginica,0.0005899177864991513,0.3930226099588011,0.6063874722546997
virginica,1.294182900898511e-05,0.038636456650029384,0.9613506015209615
virginica,4.8938786455221494e-06,0.05152126158864239,0.948473844532712
virginica,0.0010803464690480535,0.4564333808515562,0.5424862726793956
virginica,0.0010304258063547787,0.38541407005578443,0.6135555041378608
virginica,1.0747880081492008e-05,0.03637168487864255,0.9636175672412759
virginica,1.700682049407769e-05,0.1420161965315378,0.8579667966479683
virginica,1.0724581900914203e-06,0.029208438355116443,0.9707904891866935
virginica,6.194326024064337e-07,0.017239392452446812,0.9827599881149507
virginica,7.95631573087409e-06,0.027286657375518025,0.972705386308751
virginica,0.0005322830704030531,0.47564587390488194,0.5238218430247149
virginica,6.309977425142072e-05,0.18882726083611304,0.8111096393896355
virginica,3.9687751852223873e-07,0.011743431594648046,0.9882561715278334
virginica,1.1729672439925794e-05,0.017349186293852063,0.9826390840337081
virginica,6.818461339075385e-05,0.11958642495824089,0.8803453904283685
virginica,0.001632636193317133,0.44042671452375737,0.5579406492829255
virginica,3.9999595356400716e-05,0.09347541298929733,0.9064845874153463
virginica,6.36388768850334e-06,0.020299757539928888,0.9796938785723825
virginica,0.00010050364413138114,0.12056717906814068,0.8793323172877279
virginica,0.0002448021954449444,0.16269814168626087,0.8370570561182941
virginica,2.0678983469784076e-06,0.012598957562164597,0.9873989745394883
virginica,3.846534619828258e-06,0.012113291360322596,0.9878828621050575
virginica,5.6405396660988635e-05,0.0800921189542684,0.9198514756490705
virginica,0.00022888958922116654,0.25189332017782856,0.7478777902329502
virginica,0.0001394711424956285,0.15714494275998056,0.8427155860975237
virginica,4.6050051525271566e-05,0.03846824487707604,0.9614857050713987
virginica,0.0004792049830317949,0.2350528653023213,0.7644679297146469
Loading

0 comments on commit df83171

Please sign in to comment.