Skip to content

Commit

Permalink
WIP: Introduced a feature flag / session flag to enable or disable the Stats …
Browse files Browse the repository at this point in the history
…propagate feature.
  • Loading branch information
ScrapCodes committed Jun 10, 2024
1 parent 1f9afa3 commit 3444a8d
Show file tree
Hide file tree
Showing 15 changed files with 179 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ public boolean isDebug()

/**
 * Returns the client request timeout held in the {@code clientRequestTimeout} field.
 *
 * @return the value of {@code clientRequestTimeout}
 */
public Duration getClientRequestTimeout()
{
    // Return the stored value rather than a hard-coded "1h" so that whatever
    // was assigned to the field is actually honoured by callers.
    return clientRequestTimeout;
}

public boolean isCompressionDisabled()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,7 @@ public final class SystemSessionProperties
public static final String DEFAULT_VIEW_SECURITY_MODE = "default_view_security_mode";
public static final String JOIN_PREFILTER_BUILD_SIDE = "join_prefilter_build_side";
public static final String OPTIMIZER_USE_HISTOGRAMS = "optimizer_use_histograms";
public static final String ENABLE_SCALAR_FUNCTION_STATS_PROPAGATION = "enable_scalar_function_stats_propagation";

private final List<PropertyMetadata<?>> sessionProperties;

Expand Down Expand Up @@ -1937,6 +1938,10 @@ public SystemSessionProperties(
booleanProperty(OPTIMIZER_USE_HISTOGRAMS,
"whether or not to use histograms in the CBO",
featuresConfig.isUseHistograms(),
false),
booleanProperty(ENABLE_SCALAR_FUNCTION_STATS_PROPAGATION,
"whether or not to respect stats propagation annotation for scalar functions (or UDF)",
featuresConfig.isEnabledScalarFunctionStatsPropagation(),
false));
}

Expand Down Expand Up @@ -3229,4 +3234,9 @@ public static boolean shouldOptimizerUseHistograms(Session session)
{
return session.getSystemProperty(OPTIMIZER_USE_HISTOGRAMS, Boolean.class);
}

/**
 * Reads the {@code enable_scalar_function_stats_propagation} system property
 * from the given session.
 *
 * @param session session whose system properties are consulted
 * @return the boolean value of the property for this session
 */
public static boolean shouldEnableScalarFunctionStatsPropagation(Session session)
{
    Boolean enabled = session.getSystemProperty(ENABLE_SCALAR_FUNCTION_STATS_PROPAGATION, Boolean.class);
    return enabled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ public PlanNodeStatsEstimate getStats(PlanNode node)
PlanNodeStatsEstimate stats = cache.get(node);
if (stats != null) {
session.getPlanNodeStatsMap().put(node.getId(), stats);
log.info("stats " + stats + " for node: " + node + " found in cache.");
return stats;
}
stats = statsCalculator.calculateStats(node, this, lookup, session, types);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ private static class CostEstimator
public PlanCostEstimate visitPlan(PlanNode node, Void context)
{
    // Returns an unknown cost estimate for the given plan node.
    // No tracing to stdout here: this runs inside the cost estimator and a
    // System.out.println per visited node is debug residue, not logging.
    // TODO implement cost estimates for all plan nodes
    return PlanCostEstimate.unknown();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
*/
package com.facebook.presto.cost;

import com.facebook.presto.FullConnectorSession;
import com.facebook.presto.Session;
import com.facebook.presto.SystemSessionProperties;
import com.facebook.presto.common.function.OperatorType;
import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.common.type.Type;
Expand Down Expand Up @@ -127,13 +129,12 @@ public VariableStatsEstimate visitCall(CallExpression call, Void context)
}

FunctionMetadata functionMetadata = metadata.getFunctionAndTypeManager().getFunctionMetadata(call.getFunctionHandle());
System.out.println("functionMetadata arg types = " + functionMetadata.getArgumentTypes() + " functionMetadata getReturnType= " + functionMetadata.getReturnType());
if (functionMetadata.getOperatorType().map(OperatorType::isArithmeticOperator).orElse(false)) {
return computeArithmeticBinaryStatistics(call, context);
}

RowExpression value = new RowExpressionOptimizer(metadata).optimize(call, OPTIMIZED, session);
System.out.println("RowExpression = " + value);

if (isNull(value)) {
return nullStatsEstimate();
}
Expand All @@ -147,13 +148,15 @@ public VariableStatsEstimate visitCall(CallExpression call, Void context)
return computeCastStatistics(call, context);
}

if (functionMetadata.getStatsHeader().isPresent()) {
if (functionMetadata.getStatsHeader().isPresent() &&
SystemSessionProperties.shouldEnableScalarFunctionStatsPropagation(((FullConnectorSession) session).getSession())) {
return computeCallStatistics(call, context, functionMetadata.getStatsHeader().get());
}
else {
System.out.println("Stats not found for func: " + functionMetadata.getName() + " " + call);
}
// by default propagate source stats of first col.
// return propagateCallSourceStatistics(call, context);
return VariableStatsEstimate.unknown();
}

Expand Down Expand Up @@ -232,10 +235,22 @@ public VariableStatsEstimate visitSpecialForm(SpecialFormExpression specialForm,
// Combines the distinct-value statistics of every argument of `call` according to `op`
// and returns the resulting StatisticRange. The result is the empty range when no
// argument contributed and no positive type width was found.
//
// NOTE(review): this listing is a diff rendering. Adjacent near-duplicate lines (the two
// `if (!sourceStats.isUnknown() ...` guards below) appear to be the pre-change and
// post-change versions of the same statement, and the "Expand All" line marks source
// lines that are not shown here.
private StatisticRange processDistinctValueCountAndRange(CallExpression call, Void context, PropagateSourceStats op)
{
StatisticRange s = StatisticRange.empty();
// -1 acts as "no width seen yet"; only values > 0 are ever used (see the check after the loop).
double ndv = -1;
for (int i = 0; i < call.getArguments().size(); i++) {
// Stats of the i-th argument, computed by visiting it with this same visitor.
VariableStatsEstimate sourceStats = call.getArguments().get(i).accept(this, context);
if (!sourceStats.isUnknown() && !Double.isNaN(sourceStats.getDistinctValuesCount())) {
if (!sourceStats.isUnknown() && isFinite(sourceStats.getDistinctValuesCount())) {
switch (op) {
case MAX_TYPE_WIDTH:
// For VARCHAR arguments, track the largest positive long type parameter seen so far.
// NOTE(review): presumably that parameter is the declared varchar length — confirm.
TypeSignature typeSignature = call.getArguments().get(i).getType().getTypeSignature();
if (typeSignature.getTypeSignatureBase().hasStandardType() && typeSignature.getTypeSignatureBase().getStandardTypeBase().equals(StandardTypes.VARCHAR)) {
for (TypeSignatureParameter t : typeSignature.getParameters()) {
// NOTE(review): `longLiteral > 0` auto-unboxes the Long; confirm getLongLiteral()
// can never return null for the parameters that reach this point.
Long longLiteral = t.getLongLiteral();
if (longLiteral > 0) {
ndv = Math.max(ndv, longLiteral);
}
}
}
break;
case MAX:
s = s.addAndMaxDistinctValues(sourceStats.statisticRange());
break;
Expand All @@ -245,15 +260,18 @@ private StatisticRange processDistinctValueCountAndRange(CallExpression call, Vo
}
}
}
// Only applies when no range was accumulated above: build a range [0, ndv] whose
// distinct-value count is also ndv.
// NOTE(review): this uses the type width both as the upper bound and as the NDV —
// confirm that is the intended estimate.
if (s.isEmpty() && isFinite(ndv) && ndv > 0.0) {
s = new StatisticRange(0, ndv, ndv); // This would be the case of MAX_TYPE_WIDTH
}
return s;
}

// Combines the null fraction of every argument of `call` according to `op` and
// returns the combined value, or NaN when the combined value is not strictly positive.
//
// NOTE(review): this listing is a diff rendering. The paired lines below
// (`double s = NaN;` / `double s = -10;`, the two `if (!sourceStats.isUnknown() ...`
// guards, and `return s;` followed by the `if (s > 0.0)` block) appear to be the
// pre-change and post-change versions, and the "Expand All" line marks source lines
// that are not shown here.
private double processNullFraction(CallExpression call, Void context, PropagateSourceStats op)
{
double s = NaN;
// -10 is a sentinel below any value accepted by the final `s > 0.0` check.
double s = -10;
for (int i = 0; i < call.getArguments().size(); i++) {
// Stats of the i-th argument, computed by visiting it with this same visitor.
VariableStatsEstimate sourceStats = call.getArguments().get(i).accept(this, context);
if (!sourceStats.isUnknown() && !Double.isNaN(sourceStats.getNullsFraction())) {
if (!sourceStats.isUnknown() && isFinite(sourceStats.getNullsFraction())) {
switch (op) {
case MAX:
s = max(s, sourceStats.getNullsFraction());
Expand All @@ -264,15 +282,18 @@ private double processNullFraction(CallExpression call, Void context, PropagateS
}
}
}
return s;
// NOTE(review): a combined null fraction of exactly 0.0 fails this check and is
// returned as NaN, i.e. "no nulls" becomes indistinguishable from "nothing combined".
// Confirm whether `s >= 0.0` was intended.
if (s > 0.0) {
return s;
}
return NaN;
}

// Combines the average row size of every argument of `call` according to `op` and
// returns the combined value, or NaN when the combined value is not strictly positive.
//
// NOTE(review): this listing is a diff rendering. The paired lines below
// (`double s = NaN;` / `double s = -10;`, the two `if (!sourceStats.isUnknown() ...`
// guards, and `return s;` followed by the `if (s > 0.0)` block) appear to be the
// pre-change and post-change versions, and the "Expand All" line marks source lines
// that are not shown here.
private double processAvgRowSize(CallExpression call, Void context, PropagateSourceStats op)
{
double s = NaN;
// -10 is a sentinel below any value accepted by the final `s > 0.0` check.
double s = -10;
for (int i = 0; i < call.getArguments().size(); i++) {
// Stats of the i-th argument, computed by visiting it with this same visitor.
VariableStatsEstimate sourceStats = call.getArguments().get(i).accept(this, context);
if (!sourceStats.isUnknown() && !Double.isNaN(sourceStats.getAverageRowSize())) {
if (!sourceStats.isUnknown() && isFinite(sourceStats.getAverageRowSize())) {
switch (op) {
case MAX:
s = max(s, sourceStats.getAverageRowSize());
Expand All @@ -283,7 +304,10 @@ private double processAvgRowSize(CallExpression call, Void context, PropagateSou
}
}
}
return s;
// A combined size that is still the sentinel, or is exactly 0.0, is returned as NaN.
// NOTE(review): confirm that a 0.0 average row size should be reported as NaN.
if (s > 0.0) {
return s;
}
return NaN;
}

private VariableStatsEstimate computeCallStatistics(CallExpression call, Void context, ScalarStatsHeader statsHeader)
Expand All @@ -299,8 +323,8 @@ private VariableStatsEstimate computeCallStatistics(CallExpression call, Void co
// TODO: handle histograms.
for (Map.Entry<Integer, ScalarPropagateSourceStats> entry : statsHeader.getStatsResolver().entrySet()) {
ScalarPropagateSourceStats scalarPropagateSourceStats = entry.getValue();
VariableStatsEstimate sourceStats = call.getArguments().get(entry.getKey()).accept(this, context);
if (scalarPropagateSourceStats.propagateAllStats()) {
VariableStatsEstimate sourceStats = call.getArguments().get(entry.getKey()).accept(this, context);
distinctValuesCount = sourceStats.getDistinctValuesCount();
min = sourceStats.getLowValue();
max = sourceStats.getHighValue();
Expand All @@ -310,36 +334,19 @@ private VariableStatsEstimate computeCallStatistics(CallExpression call, Void co
// distinct value count
switch (scalarPropagateSourceStats.distinctValueCount()) {
case SOURCE_STATS:
VariableStatsEstimate sourceStats = call.getArguments().get(entry.getKey()).accept(this, context);
distinctValuesCount = sourceStats.getDistinctValuesCount();
break;
case ROW_COUNT:
distinctValuesCount = input.getOutputRowCount();
break;
case MAX_TYPE_WIDTH:
TypeSignature typeSignature = call.getArguments().get(entry.getKey()).getType().getTypeSignature();
if (typeSignature.getTypeSignatureBase().hasStandardType() && typeSignature.getTypeSignatureBase().getStandardTypeBase().equals(StandardTypes.VARCHAR)) {
double varcharTypeWidth = 0.0;
for (TypeSignatureParameter t : typeSignature.getParameters()) {
Long longLiteral = t.getLongLiteral();
varcharTypeWidth = Math.max(longLiteral, varcharTypeWidth);
}
if (varcharTypeWidth != 0) {
distinctValuesCount = varcharTypeWidth;
}
}
if (!isFinite(distinctValuesCount) || distinctValuesCount == 0) {
distinctValuesCount = input.getOutputRowCount();
}
break;
case MAX:
case SUM:
statisticRange = processDistinctValueCountAndRange(call, context, scalarPropagateSourceStats.distinctValueCount());
}
// min, max can be estimated by distinct value count as well, but user provided hints/values override those.
switch (scalarPropagateSourceStats.minValue()) {
case SOURCE_STATS:
VariableStatsEstimate sourceStats = call.getArguments().get(entry.getKey()).accept(this, context);
min = sourceStats.getLowValue();
break;
case MAX:
Expand All @@ -348,21 +355,9 @@ private VariableStatsEstimate computeCallStatistics(CallExpression call, Void co
}
switch (scalarPropagateSourceStats.maxValue()) {
case SOURCE_STATS:
VariableStatsEstimate sourceStats = call.getArguments().get(entry.getKey()).accept(this, context);
max = sourceStats.getHighValue();
break;
case MAX_TYPE_WIDTH:
TypeSignature typeSignature = call.getArguments().get(entry.getKey()).getType().getTypeSignature();
if (typeSignature.getTypeSignatureBase().hasStandardType() && typeSignature.getTypeSignatureBase().getStandardTypeBase().equals(StandardTypes.VARCHAR)) {
double varcharTypeWidth = 0.0;
for (TypeSignatureParameter t : typeSignature.getParameters()) {
Long longLiteral = t.getLongLiteral();
varcharTypeWidth = Math.max(longLiteral, varcharTypeWidth);
}
if (varcharTypeWidth != 0) {
max = varcharTypeWidth;
}
}
case MAX_TYPE_WIDTH: // Handled as part of distinct value count
break;
case MAX:
case SUM:
Expand All @@ -371,7 +366,6 @@ private VariableStatsEstimate computeCallStatistics(CallExpression call, Void co
// Average row size
switch (scalarPropagateSourceStats.avgRowSize()) {
case SOURCE_STATS:
VariableStatsEstimate sourceStats = call.getArguments().get(entry.getKey()).accept(this, context);
avgRowSize = sourceStats.getAverageRowSize();
break;
case MAX:
Expand All @@ -381,26 +375,34 @@ private VariableStatsEstimate computeCallStatistics(CallExpression call, Void co
// Null fraction
switch (scalarPropagateSourceStats.nullFraction()) {
case SOURCE_STATS:
VariableStatsEstimate sourceStats = call.getArguments().get(entry.getKey()).accept(this, context);
nullFraction = sourceStats.getNullsFraction();
break;
case MAX:
case SUM:
nullFraction = processNullFraction(call, context, scalarPropagateSourceStats.nullFraction());
}
}

// If min and max are set via propagate stats
if (isFinite(min) && isFinite(max)) {
statisticRange = new StatisticRange(min, max, distinctValuesCount);
}
// Constant values override any values.
if (isFinite(statsHeader.getNullFraction())) {
nullFraction = statsHeader.getNullFraction();
}
if (isFinite(statsHeader.getAvgRowSize())) {
avgRowSize = statsHeader.getAvgRowSize();
}
if (isFinite(statsHeader.getDistinctValuesCount())) {
distinctValuesCount = statsHeader.getDistinctValuesCount();
if (statsHeader.getDistinctValuesCount() == -1.0) {
distinctValuesCount = input.getOutputRowCount();
}
else {
distinctValuesCount = statsHeader.getDistinctValuesCount();
}
}
if (isFinite(min) && isFinite(max)) {
statisticRange = new StatisticRange(min, max, distinctValuesCount);
if (isFinite(statsHeader.getMin()) && isFinite(statsHeader.getMax())) {
statisticRange = new StatisticRange(statsHeader.getMin(), statsHeader.getMax(), distinctValuesCount);
}
sourceStatsSum = VariableStatsEstimate.builder().setStatisticsRange(statisticRange)
.setAverageRowSize(avgRowSize)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@

import com.facebook.presto.common.type.StandardTypes;
import com.facebook.presto.spi.function.ScalarFunction;
import com.facebook.presto.spi.function.ScalarFunctionConstantStats;
import com.facebook.presto.spi.function.ScalarPropagateSourceStats;
import com.facebook.presto.spi.function.SqlType;

import static com.facebook.presto.spi.function.PropagateSourceStats.SOURCE_STATS;
import static com.facebook.presto.spi.function.SqlFunctionVisibility.HIDDEN;

public final class CombineHashFunction
Expand All @@ -25,7 +28,10 @@ private CombineHashFunction() {}

@ScalarFunction(value = "combine_hash", visibility = HIDDEN)
@SqlType(StandardTypes.BIGINT)
public static long getHash(@SqlType(StandardTypes.BIGINT) long previousHashValue, @SqlType(StandardTypes.BIGINT) long value)
@ScalarFunctionConstantStats(avgRowSize = 8, distinctValuesCount = -1)
public static long getHash(
@ScalarPropagateSourceStats(nullFraction = SOURCE_STATS) @SqlType(StandardTypes.BIGINT) long previousHashValue,
@SqlType(StandardTypes.BIGINT) long value)
{
return (31 * previousHashValue + value);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,7 @@ public static double radians(@SqlType(StandardTypes.DOUBLE) double degrees)
@Description("a pseudo-random value")
@ScalarFunction(alias = "rand", deterministic = false)
@SqlType(StandardTypes.DOUBLE)
@ScalarFunctionConstantStats(avgRowSize = 8, nullFraction = 0.0, minValue = 0, maxValue = 1.0, distinctValuesCount = -1) // ndv of -1 = row count
public static double random()
{
return ThreadLocalRandom.current().nextDouble();
Expand Down
Loading

0 comments on commit 3444a8d

Please sign in to comment.