Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(isthmus): improved Calcite support for Substrait Aggregate rels #214

Merged
merged 12 commits into from
Jan 11, 2024
180 changes: 132 additions & 48 deletions core/src/main/java/io/substrait/dsl/SubstraitBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,17 @@ public SubstraitBuilder(SimpleExtension.ExtensionCollection extensions) {
}

// Relations
/**
 * Wraps an aggregate function invocation in an {@link Aggregate.Measure}.
 *
 * <p>Only the measure's function is set; no pre-measure filter is supplied.
 *
 * @param aggFn the aggregate function invocation to use as the measure's function
 * @return the built measure
 */
public Aggregate.Measure measure(AggregateFunctionInvocation aggFn) {
  var measureBuilder = Aggregate.Measure.builder().function(aggFn);
  return measureBuilder.build();
}

/**
 * Wraps an aggregate function invocation and a pre-measure filter in an {@link
 * Aggregate.Measure}.
 *
 * @param aggFn the aggregate function invocation to use as the measure's function
 * @param preMeasureFilter the expression to set as the measure's pre-measure filter
 * @return the built measure
 */
public Aggregate.Measure measure(AggregateFunctionInvocation aggFn, Expression preMeasureFilter) {
  var measureBuilder =
      Aggregate.Measure.builder().function(aggFn).preMeasureFilter(preMeasureFilter);
  return measureBuilder.build();
}

public Aggregate aggregate(
Function<Rel, Aggregate.Grouping> groupingFn,
Function<Rel, List<AggregateFunctionInvocation>> measuresFn,
Function<Rel, List<Aggregate.Measure>> measuresFn,
Rel input) {
vbarua marked this conversation as resolved.
Show resolved Hide resolved
Function<Rel, List<Aggregate.Grouping>> groupingsFn =
groupingFn.andThen(g -> Stream.of(g).collect(Collectors.toList()));
Expand All @@ -64,7 +72,7 @@ public Aggregate aggregate(

public Aggregate aggregate(
Function<Rel, Aggregate.Grouping> groupingFn,
Function<Rel, List<AggregateFunctionInvocation>> measuresFn,
Function<Rel, List<Aggregate.Measure>> measuresFn,
Rel.Remap remap,
Rel input) {
Function<Rel, List<Aggregate.Grouping>> groupingsFn =
Expand All @@ -74,14 +82,11 @@ public Aggregate aggregate(

private Aggregate aggregate(
Function<Rel, List<Aggregate.Grouping>> groupingsFn,
Function<Rel, List<AggregateFunctionInvocation>> measuresFn,
Function<Rel, List<Aggregate.Measure>> measuresFn,
Optional<Rel.Remap> remap,
Rel input) {
var groupings = groupingsFn.apply(input);
var measures =
measuresFn.apply(input).stream()
.map(m -> Aggregate.Measure.builder().function(m).build())
.collect(java.util.stream.Collectors.toList());
var measures = measuresFn.apply(input);
return Aggregate.builder()
.groupings(groupings)
.measures(measures)
Expand Down Expand Up @@ -389,6 +394,11 @@ public List<Expression.SortField> sortFields(Rel input, int... indexes) {
.collect(java.util.stream.Collectors.toList());
}

/**
 * Builds a single {@link Expression.SortField} from an expression and a sort direction.
 *
 * @param expression the expression to sort by
 * @param sortDirection the direction to apply to the sort field
 * @return the built sort field
 */
public Expression.SortField sortField(
    Expression expression, Expression.SortDirection sortDirection) {
  var sortFieldBuilder = Expression.SortField.builder().expr(expression).direction(sortDirection);
  return sortFieldBuilder.build();
}

/**
 * Builds a {@link SwitchClause} from a literal condition and its result expression.
 *
 * @param condition the literal set as the clause's condition
 * @param then the expression set as the clause's {@code then} value
 * @return the built switch clause
 */
public SwitchClause switchClause(Expression.Literal condition, Expression then) {
  var clauseBuilder = SwitchClause.builder().condition(condition).then(then);
  return clauseBuilder.build();
}
Expand Down Expand Up @@ -422,76 +432,150 @@ public Aggregate.Grouping grouping(Rel input, int... indexes) {
return Aggregate.Grouping.builder().addAllExpressions(columns).build();
}

public AggregateFunctionInvocation count(Rel input, int field) {
/**
 * Builds an {@link Aggregate.Grouping} directly from the given grouping expressions.
 *
 * @param expressions the expressions added to the grouping, in the order given
 * @return the built grouping
 */
public Aggregate.Grouping grouping(Expression... expressions) {
  var groupingBuilder = Aggregate.Grouping.builder().addExpressions(expressions);
  return groupingBuilder.build();
}

public Aggregate.Measure count(Rel input, int field) {
var declaration =
extensions.getAggregateFunction(
SimpleExtension.FunctionAnchor.of(
DefaultExtensionCatalog.FUNCTIONS_AGGREGATE_GENERIC, "count:any"));
return AggregateFunctionInvocation.builder()
.arguments(fieldReferences(input, field))
.outputType(R.I64)
.declaration(declaration)
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.build();
return measure(
AggregateFunctionInvocation.builder()
.arguments(fieldReferences(input, field))
.outputType(R.I64)
.declaration(declaration)
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.build());
}

/**
 * Builds a {@code min} measure over a field of the input relation.
 *
 * <p>Delegates to {@code min(Expression)} with the result of {@code fieldReference(input,
 * field)}.
 *
 * @param input the relation the field is taken from
 * @param field the field passed to {@code fieldReference}
 * @return the measure produced by {@code min(Expression)}
 */
public Aggregate.Measure min(Rel input, int field) {
  var column = fieldReference(input, field);
  return min(column);
}

public AggregateFunctionInvocation min(Rel input, int field) {
Type inputType = input.getRecordType().fields().get(field);
// min output is always nullable
public Aggregate.Measure min(Expression expr) {
return singleArgumentArithmeticAggregate(
input, field, "min", TypeCreator.asNullable(inputType));
expr,
"min",
// min output is always nullable
TypeCreator.asNullable(expr.getType()));
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added arithmetic aggregate functions that take an expression directly.

}

public AggregateFunctionInvocation max(Rel input, int field) {
Type inputType = input.getRecordType().fields().get(field);
// max output is always nullable
/**
 * Builds a {@code max} measure over a field of the input relation.
 *
 * <p>Delegates to {@code max(Expression)} with the result of {@code fieldReference(input,
 * field)}.
 *
 * @param input the relation the field is taken from
 * @param field the field passed to {@code fieldReference}
 * @return the measure produced by {@code max(Expression)}
 */
public Aggregate.Measure max(Rel input, int field) {
  var column = fieldReference(input, field);
  return max(column);
}

public Aggregate.Measure max(Expression expr) {
return singleArgumentArithmeticAggregate(
input, field, "max", TypeCreator.asNullable(inputType));
expr,
"max",
// max output is always nullable
TypeCreator.asNullable(expr.getType()));
}

public AggregateFunctionInvocation avg(Rel input, int field) {
Type inputType = input.getRecordType().fields().get(field);
// avg output is always nullable
/**
 * Builds an {@code avg} measure over a field of the input relation.
 *
 * <p>Delegates to {@code avg(Expression)} with the result of {@code fieldReference(input,
 * field)}.
 *
 * @param input the relation the field is taken from
 * @param field the field passed to {@code fieldReference}
 * @return the measure produced by {@code avg(Expression)}
 */
public Aggregate.Measure avg(Rel input, int field) {
  var column = fieldReference(input, field);
  return avg(column);
}

public Aggregate.Measure avg(Expression expr) {
return singleArgumentArithmeticAggregate(
input, field, "avg", TypeCreator.asNullable(inputType));
expr,
"avg",
// avg output is always nullable
TypeCreator.asNullable(expr.getType()));
}

/**
 * Builds a {@code sum} measure over a field of the input relation.
 *
 * <p>Delegates to {@code sum(Expression)} with the result of {@code fieldReference(input,
 * field)}.
 *
 * @param input the relation the field is taken from
 * @param field the field passed to {@code fieldReference}
 * @return the measure produced by {@code sum(Expression)}
 */
public Aggregate.Measure sum(Rel input, int field) {
  var column = fieldReference(input, field);
  return sum(column);
}

public AggregateFunctionInvocation sum(Rel input, int field) {
Type inputType = input.getRecordType().fields().get(field);
// sum output is always nullable
public Aggregate.Measure sum(Expression expr) {
return singleArgumentArithmeticAggregate(
input, field, "sum", TypeCreator.asNullable(inputType));
expr,
"sum",
// sum output is always nullable
TypeCreator.asNullable(expr.getType()));
}

public AggregateFunctionInvocation sum0(Rel input, int field) {
// sum0 output is always NOT NULL I64
return singleArgumentArithmeticAggregate(input, field, "sum0", R.I64);
/**
 * Builds a {@code sum0} measure over a field of the input relation.
 *
 * <p>Delegates to {@code sum0(Expression)} with the result of {@code fieldReference(input,
 * field)}. It must not delegate to {@code sum(Expression)}: that overload looks up the {@code
 * sum} function and reports a nullable output type, whereas {@code sum0(Expression)} looks up
 * {@code sum0} and reports {@code R.I64}.
 *
 * @param input the relation the field is taken from
 * @param field the field passed to {@code fieldReference}
 * @return the measure produced by {@code sum0(Expression)}
 */
public Aggregate.Measure sum0(Rel input, int field) {
  return sum0(fieldReference(input, field));
}

private AggregateFunctionInvocation singleArgumentArithmeticAggregate(
Rel input, int field, String functionName, Type outputType) {
Type inputType = input.getRecordType().fields().get(field);
String typeString = inputType.accept(ToTypeString.INSTANCE);
/**
 * Builds a {@code sum0} measure over the given expression.
 *
 * @param expr the expression to aggregate
 * @return the measure produced by {@code singleArgumentArithmeticAggregate} for {@code sum0}
 */
public Aggregate.Measure sum0(Expression expr) {
  // sum0 output is always NOT NULL I64
  Type sum0OutputType = R.I64;
  return singleArgumentArithmeticAggregate(expr, "sum0", sum0OutputType);
}

private Aggregate.Measure singleArgumentArithmeticAggregate(
Expression expr, String functionName, Type outputType) {
String typeString = ToTypeString.apply(expr.getType());
var declaration =
extensions.getAggregateFunction(
SimpleExtension.FunctionAnchor.of(
DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC,
String.format("%s:%s", functionName, typeString)));
return AggregateFunctionInvocation.builder()
.arguments(fieldReferences(input, field))
.outputType(outputType)
.declaration(declaration)
// INITIAL_TO_RESULT is the most restrictive aggregation phase type,
// as it does not allow decomposition. Use it as the default for now.
// TODO: set this per function
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.build();
return measure(
AggregateFunctionInvocation.builder()
.arguments(Arrays.asList(expr))
.outputType(outputType)
.declaration(declaration)
// INITIAL_TO_RESULT is the most restrictive aggregation phase type,
// as it does not allow decomposition. Use it as the default for now.
// TODO: set this per function
.aggregationPhase(Expression.AggregationPhase.INITIAL_TO_RESULT)
.invocation(Expression.AggregationInvocation.ALL)
.build());
}

// Scalar Functions

/**
 * Builds a {@code negate} scalar function invocation over the given expression.
 *
 * <p>The function key is {@code negate:<type>}, where {@code <type>} is the {@link ToTypeString}
 * rendering of the expression's type. The invocation's output type is the expression's own
 * type.
 *
 * @param expr the expression to negate
 * @return the scalar function invocation produced by {@code scalarFn}
 */
public Expression.ScalarFunctionInvocation negate(Expression expr) {
  // negation does not change the type, so the input type doubles as the output type
  var exprType = expr.getType();
  var key = String.format("negate:%s", ToTypeString.apply(exprType));
  return scalarFn(DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, key, exprType, expr);
}

/**
 * Builds an {@code add} scalar function invocation over two operands.
 *
 * @param left the first argument
 * @param right the second argument
 * @return the invocation produced by {@code arithmeticFunction} for {@code "add"}
 */
public Expression.ScalarFunctionInvocation add(Expression left, Expression right) {
  final String functionName = "add";
  return arithmeticFunction(functionName, left, right);
}

/**
 * Builds a {@code subtract} scalar function invocation over two operands.
 *
 * <p>The name passed to {@code arithmeticFunction} becomes the prefix of the function key
 * ({@code <name>:<leftType>_<rightType>}), so it must be spelled exactly {@code "subtract"} to
 * match the function of that name in the Substrait arithmetic extension.
 *
 * @param left the first argument (the minuend)
 * @param right the second argument (the subtrahend)
 * @return the invocation produced by {@code arithmeticFunction} for {@code "subtract"}
 */
public Expression.ScalarFunctionInvocation subtract(Expression left, Expression right) {
  return arithmeticFunction("subtract", left, right);
}

/**
 * Builds a {@code multiply} scalar function invocation over two operands.
 *
 * @param left the first argument
 * @param right the second argument
 * @return the invocation produced by {@code arithmeticFunction} for {@code "multiply"}
 */
public Expression.ScalarFunctionInvocation multiply(Expression left, Expression right) {
  final String functionName = "multiply";
  return arithmeticFunction(functionName, left, right);
}

/**
 * Builds a {@code divide} scalar function invocation over two operands.
 *
 * @param left the first argument
 * @param right the second argument
 * @return the invocation produced by {@code arithmeticFunction} for {@code "divide"}
 */
public Expression.ScalarFunctionInvocation divide(Expression left, Expression right) {
  final String functionName = "divide";
  return arithmeticFunction(functionName, left, right);
}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I needed to nest some functions within aggregate function calls for testing purposes, so I added some easy ones.


/**
 * Builds a two-argument scalar function invocation from the arithmetic extension.
 *
 * <p>The function key has the form {@code <fname>:<leftType>_<rightType>}, using the {@link
 * ToTypeString} rendering of each operand's type. The output type is the left operand's type,
 * made nullable when either operand's type is nullable and not-nullable otherwise.
 *
 * @param fname the function name used as the key prefix
 * @param left the first argument
 * @param right the second argument
 * @return the scalar function invocation produced by {@code scalarFn}
 */
private Expression.ScalarFunctionInvocation arithmeticFunction(
    String fname, Expression left, Expression right) {
  var leftType = left.getType();
  var leftTypeName = ToTypeString.apply(leftType);
  var rightType = right.getType();
  var rightTypeName = ToTypeString.apply(rightType);
  var signature = String.format("%s:%s_%s", fname, leftTypeName, rightTypeName);

  // the result is nullable as soon as one of the operands is
  var resultType = leftType;
  if (leftType.nullable() || rightType.nullable()) {
    resultType = TypeCreator.asNullable(resultType);
  } else {
    resultType = TypeCreator.asNotNullable(resultType);
  }

  return scalarFn(
      DefaultExtensionCatalog.FUNCTIONS_ARITHMETIC, signature, resultType, left, right);
}

public Expression.ScalarFunctionInvocation equal(Expression left, Expression right) {
return scalarFn(
DefaultExtensionCatalog.FUNCTIONS_COMPARISON, "equal:any_any", R.BOOLEAN, left, right);
Expand Down
6 changes: 5 additions & 1 deletion core/src/main/java/io/substrait/function/ToTypeString.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@
public class ToTypeString
extends ParameterizedTypeVisitor.ParameterizedTypeThrowsVisitor<String, RuntimeException> {

public static ToTypeString INSTANCE = new ToTypeString();
public static final ToTypeString INSTANCE = new ToTypeString();
vbarua marked this conversation as resolved.
Show resolved Hide resolved

/**
 * Renders a type as a string by having it accept the shared {@code INSTANCE} visitor.
 *
 * @param type the type to render
 * @return the string returned by the visitor for {@code type}
 */
public static String apply(Type type) {
  final String rendered = type.accept(INSTANCE);
  return rendered;
}

private ToTypeString() {
super("Only type literals and parameterized types can be used in functions.");
Expand Down
Loading