Skip to content

Commit

Permalink
Merged version 1.3.15
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Jan 8, 2020
2 parents 677a9f6 + d6780ee commit 96d5802
Show file tree
Hide file tree
Showing 17 changed files with 186 additions and 90 deletions.
24 changes: 19 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,23 @@ JPMML-SparkML

Java library and command-line application for converting Apache Spark ML pipelines to PMML.

# Table of Contents #

* [Features](#features)
* [Prerequisites](#prerequisites)
* [Installation](#installation)
* [Library](#library)
* [Example application](#example-application)
* [Usage](#usage)
* [Library](#library-1)
* [Example application](#example-application-1)
* [Documentation](#documentation)
* [License](#license)
* [Additional information](#additional-information)

# Features #

* Supported Spark ML `PipelineStage` types:
* Supported pipeline stage types:
* Feature extractors, transformers and selectors:
* [`feature.Binarizer`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/Binarizer.html)
* [`feature.Bucketizer`](https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/feature/Bucketizer.html)
Expand Down Expand Up @@ -75,7 +89,7 @@ Java library and command-line application for converting Apache Spark ML pipelin

# Installation #

## Library ##
### Library

The JPMML-SparkML library JAR file (together with the accompanying Java source and Javadoc JAR files) is released via the [Maven Central Repository](https://repo1.maven.org/maven2/org/jpmml/).

Expand Down Expand Up @@ -104,7 +118,7 @@ JPMML-SparkML depends on the latest and greatest version of the [JPMML-Model](ht

This conflict is documented in [SPARK-15526](https://issues.apache.org/jira/browse/SPARK-15526). For possible resolutions, please refer to the README.md file of an earlier JPMML-SparkML development branch.

## Example application ##
### Example application

Enter the project root directory and build using [Apache Maven](https://maven.apache.org/):
```
Expand All @@ -117,7 +131,7 @@ The build produces two JAR files:

# Usage #

## Library ##
### Library

Fitting a Spark ML pipeline that only makes use of supported Transformer types:
```java
Expand Down Expand Up @@ -147,7 +161,7 @@ PMML pmml = new PMMLBuilder(schema, pipelineModel)
JAXBUtil.marshalPMML(pmml, new StreamResult(System.out));
```

## Example application ##
### Example application

The example application JAR file contains an executable class `org.jpmml.sparkml.Main`, which can be used to convert a pair of serialized `org.apache.spark.sql.types.StructType` and `org.apache.spark.ml.PipelineModel` objects to PMML.

Expand Down
10 changes: 5 additions & 5 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>jpmml-converter</artifactId>
<version>1.3.9</version>
<version>1.3.10</version>
</dependency>

<dependency>
Expand Down Expand Up @@ -116,7 +116,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<version>3.1.2</version>
<version>3.2.0</version>
<configuration>
<archive>
<manifest>
Expand All @@ -128,7 +128,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>3.1.0</version>
<version>3.1.1</version>
<configuration>
<javadocVersion>1.8</javadocVersion>
</configuration>
Expand Down Expand Up @@ -176,7 +176,7 @@
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<version>3.1.0</version>
<version>3.2.1</version>
<executions>
<execution>
<id>attach-sources</id>
Expand All @@ -198,7 +198,7 @@
<plugin>
<groupId>org.jacoco</groupId>
<artifactId>jacoco-maven-plugin</artifactId>
<version>0.8.4</version>
<version>0.8.5</version>
<executions>
<execution>
<id>pre-unit-test</id>
Expand Down
17 changes: 10 additions & 7 deletions src/main/java/org/jpmml/sparkml/ClassificationModelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MapValues;
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Model;
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.Feature;
import org.jpmml.converter.IndexFeature;
import org.jpmml.converter.Label;
Expand All @@ -56,33 +58,33 @@ public MiningFunction getMiningFunction(){
}

@Override
public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder){
public List<OutputField> registerOutputFields(Label label, Model pmmlModel, SparkMLEncoder encoder){
T model = getTransformer();

CategoricalLabel categoricalLabel = (CategoricalLabel)label;

List<Integer> categories = LabelUtil.createTargetCategories(categoricalLabel.size());

List<OutputField> result = new ArrayList<>();

String predictionCol = model.getPredictionCol();

OutputField pmmlPredictedField = ModelUtil.createPredictedField(FieldName.create("pmml(" + predictionCol + ")"), OpType.CATEGORICAL, categoricalLabel.getDataType())
OutputField pmmlPredictedOutputField = ModelUtil.createPredictedField(FieldName.create("pmml(" + predictionCol + ")"), OpType.CATEGORICAL, categoricalLabel.getDataType())
.setFinalResult(false);

result.add(pmmlPredictedField);
DerivedOutputField pmmlPredictedField = encoder.createDerivedField(pmmlModel, pmmlPredictedOutputField, true);

MapValues mapValues = PMMLUtil.createMapValues(pmmlPredictedField.getName(), categoricalLabel.getValues(), categories)
.setDataType(DataType.DOUBLE);

OutputField predictedField = new OutputField(FieldName.create(predictionCol), OpType.CONTINUOUS, DataType.DOUBLE)
OutputField predictedOutputField = new OutputField(FieldName.create(predictionCol), OpType.CONTINUOUS, DataType.DOUBLE)
.setResultFeature(ResultFeature.TRANSFORMED_VALUE)
.setExpression(mapValues);

result.add(predictedField);
DerivedOutputField predictedField = encoder.createDerivedField(pmmlModel, predictedOutputField, true);

encoder.putOnlyFeature(predictionCol, new IndexFeature(encoder, predictedField, categories));

List<OutputField> result = new ArrayList<>();

if(model instanceof HasProbabilityCol){
HasProbabilityCol hasProbabilityCol = (HasProbabilityCol)model;

Expand All @@ -100,6 +102,7 @@ public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encode
features.add(new ContinuousFeature(encoder, probabilityField));
}

// XXX
encoder.putFeatures(probabilityCol, features);
}

Expand Down
17 changes: 8 additions & 9 deletions src/main/java/org/jpmml/sparkml/ClusteringModelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
*/
package org.jpmml.sparkml;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import org.apache.spark.ml.Model;
Expand All @@ -31,6 +31,7 @@
import org.dmg.pmml.OpType;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.ResultFeature;
import org.jpmml.converter.DerivedOutputField;
import org.jpmml.converter.IndexFeature;
import org.jpmml.converter.Label;
import org.jpmml.converter.LabelUtil;
Expand All @@ -52,28 +53,26 @@ public MiningFunction getMiningFunction(){
}

@Override
public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder){
public List<OutputField> registerOutputFields(Label label, org.dmg.pmml.Model pmmlModel, SparkMLEncoder encoder){
T model = getTransformer();

List<Integer> clusters = LabelUtil.createTargetCategories(getNumberOfClusters());

List<OutputField> result = new ArrayList<>();

String predictionCol = model.getPredictionCol();

OutputField pmmlPredictedField = ModelUtil.createPredictedField(FieldName.create("pmml(" + predictionCol + ")"), OpType.CATEGORICAL, DataType.STRING)
OutputField pmmlPredictedOutputField = ModelUtil.createPredictedField(FieldName.create("pmml(" + predictionCol + ")"), OpType.CATEGORICAL, DataType.STRING)
.setFinalResult(false);

result.add(pmmlPredictedField);
DerivedOutputField pmmlPredictedField = encoder.createDerivedField(pmmlModel, pmmlPredictedOutputField, true);

OutputField predictedField = new OutputField(FieldName.create(predictionCol), OpType.CATEGORICAL, DataType.INTEGER)
OutputField predictedOutputField = new OutputField(FieldName.create(predictionCol), OpType.CATEGORICAL, DataType.INTEGER)
.setResultFeature(ResultFeature.TRANSFORMED_VALUE)
.setExpression(new FieldRef(pmmlPredictedField.getName()));

result.add(predictedField);
DerivedOutputField predictedField = encoder.createDerivedField(pmmlModel, predictedOutputField, true);

encoder.putOnlyFeature(predictionCol, new IndexFeature(encoder, predictedField, clusters));

return result;
return Collections.emptyList();
}
}
10 changes: 5 additions & 5 deletions src/main/java/org/jpmml/sparkml/ConverterFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public TransformerConverter<?> newConverter(Transformer transformer){

converter = converterConstructor.newInstance(transformer);
} catch(ReflectiveOperationException roe){
throw new IllegalArgumentException(roe);
throw new IllegalArgumentException("Transformer class " + clazz.getName() + " is not supported", roe);
}

if(converter != null){
Expand Down Expand Up @@ -188,8 +188,8 @@ private void init(ClassLoader classLoader, Properties properties){
continue;
}

if(clazz == null || !(Transformer.class).isAssignableFrom(clazz)){
throw new IllegalArgumentException("Expected " + Transformer.class.getName() + " subclass, got " + (clazz != null ? clazz.getName() : null));
if(!(Transformer.class).isAssignableFrom(clazz)){
throw new IllegalArgumentException("Transformer class " + clazz.getName() + " is not a subclass of " + Transformer.class.getName());
} // End if

Class<? extends TransformerConverter<?>> converterClazz;
Expand All @@ -202,8 +202,8 @@ private void init(ClassLoader classLoader, Properties properties){
continue;
}

if(converterClazz == null || !(TransformerConverter.class).isAssignableFrom(converterClazz)){
throw new IllegalArgumentException("Expected " + TransformerConverter.class.getName() + " subclass, got " + (converterClazz != null ? converterClazz.getName() : null));
if(!(TransformerConverter.class).isAssignableFrom(converterClazz)){
throw new IllegalArgumentException("Transformer converter class " + converterClazz.getName() + " is not a subclass of " + TransformerConverter.class.getName());
}

ConverterFactory.converters.put(clazz, converterClazz);
Expand Down
7 changes: 6 additions & 1 deletion src/main/java/org/jpmml/sparkml/ExpressionTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,12 @@ private Object toSimpleObject(Object value){

static
private String formatMessage(Expression expression){
return "Spark SQL function \'" + String.valueOf(expression) + "\' (Java class " + (expression.getClass()).getName() + ") is not supported";

if(expression == null){
return null;
}

return "Spark SQL function \'" + expression + "\' (class " + (expression.getClass()).getName() + ") is not supported";
}

private static final Package javaLangPackage = Package.getPackage("java.lang");
Expand Down
21 changes: 5 additions & 16 deletions src/main/java/org/jpmml/sparkml/ModelConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
import org.dmg.pmml.MiningFunction;
import org.dmg.pmml.Output;
import org.dmg.pmml.OutputField;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.BooleanFeature;
import org.jpmml.converter.CategoricalFeature;
import org.jpmml.converter.CategoricalLabel;
Expand Down Expand Up @@ -160,7 +159,7 @@ public Schema encodeSchema(SparkMLEncoder encoder){
return result;
}

public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder){
public List<OutputField> registerOutputFields(Label label, org.dmg.pmml.Model model, SparkMLEncoder encoder){
return null;
}

Expand All @@ -171,25 +170,15 @@ public org.dmg.pmml.Model registerModel(SparkMLEncoder encoder){

org.dmg.pmml.Model model = encodeModel(schema);

List<OutputField> sparkOutputFields = registerOutputFields(label, encoder);
List<OutputField> sparkOutputFields = registerOutputFields(label, model, encoder);
if(sparkOutputFields != null && sparkOutputFields.size() > 0){
Output output;
org.dmg.pmml.Model finalModel = MiningModelUtil.getFinalModel(model);

if(model instanceof MiningModel){
MiningModel miningModel = (MiningModel)model;

org.dmg.pmml.Model finalModel = MiningModelUtil.getFinalModel(miningModel);

output = ModelUtil.ensureOutput(finalModel);
} else

{
output = ModelUtil.ensureOutput(model);
}
Output output = ModelUtil.ensureOutput(finalModel);

List<OutputField> outputFields = output.getOutputFields();

outputFields.addAll(0, sparkOutputFields);
outputFields.addAll(sparkOutputFields);
}

return model;
Expand Down
28 changes: 18 additions & 10 deletions src/main/java/org/jpmml/sparkml/PMMLBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ public PMMLBuilder(StructType schema, PipelineModel pipelineModel){
setPipelineModel(pipelineModel);
}

public PMMLBuilder(StructType schema, PipelineStage pipelineStage){
throw new IllegalArgumentException("Expected a fitted pipeline model (class " + PipelineModel.class.getName() + "), got a pipeline stage (" + (pipelineStage != null ? ("class " + (pipelineStage.getClass()).getName()) : null) + ")");
}

public PMML build(){
StructType schema = getSchema();
PipelineModel pipelineModel = getPipelineModel();
Expand Down Expand Up @@ -139,7 +143,7 @@ public PMML build(){
} else

{
throw new IllegalArgumentException("Expected a " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + " instance, got " + converter);
throw new IllegalArgumentException("Expected a subclass of " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + ", got " + (converter != null ? ("class " + (converter.getClass()).getName()) : null));
}
}

Expand All @@ -159,20 +163,24 @@ public PMML build(){

{
throw new IllegalArgumentException("Expected a pipeline with one or more models, got a pipeline with zero models");
}
} // End if

for(FieldName postProcessorName : postProcessorNames){
DerivedField derivedField = derivedFields.get(postProcessorName);
if(postProcessorNames.size() > 0){
org.dmg.pmml.Model finalModel = MiningModelUtil.getFinalModel(model);

encoder.removeDerivedField(postProcessorName);
Output output = ModelUtil.ensureOutput(finalModel);

Output output = ModelUtil.ensureOutput(model);
for(FieldName postProcessorName : postProcessorNames){
DerivedField derivedField = derivedFields.get(postProcessorName);

OutputField outputField = new OutputField(derivedField.getName(), derivedField.getOpType(), derivedField.getDataType())
.setResultFeature(ResultFeature.TRANSFORMED_VALUE)
.setExpression(derivedField.getExpression());
encoder.removeDerivedField(postProcessorName);

output.addOutputFields(outputField);
OutputField outputField = new OutputField(derivedField.getName(), derivedField.getOpType(), derivedField.getDataType())
.setResultFeature(ResultFeature.TRANSFORMED_VALUE)
.setExpression(derivedField.getExpression());

output.addOutputFields(outputField);
}
}

PMML pmml = encoder.encodePMML(model);
Expand Down
Loading

0 comments on commit 96d5802

Please sign in to comment.