Simplify join to time spine logic WIP - need to validate snapshots
courtneyholcomb committed Nov 21, 2024
1 parent 2f17fa0 commit 1de52ca
Showing 49 changed files with 308 additions and 338 deletions.
150 changes: 50 additions & 100 deletions metricflow/plan_conversion/dataflow_to_sql.py
@@ -3,7 +3,7 @@
import datetime as dt
import logging
from collections import OrderedDict
from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar, Union
from typing import Callable, Dict, FrozenSet, List, Optional, Sequence, Set, Tuple, TypeVar

from dbt_semantic_interfaces.enum_extension import assert_values_exhausted
from dbt_semantic_interfaces.protocols.metric import MetricInputMeasure, MetricType
@@ -39,6 +39,7 @@
from metricflow_semantics.specs.where_filter.where_filter_spec import WhereFilterSpec
from metricflow_semantics.sql.sql_join_type import SqlJoinType
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.time.granularity import ExpandedTimeGranularity
from metricflow_semantics.time.time_constants import ISO8601_PYTHON_FORMAT, ISO8601_PYTHON_TS_FORMAT
from metricflow_semantics.time.time_spine_source import TIME_SPINE_DATA_SET_DESCRIPTION, TimeSpineSource
from typing_extensions import override
@@ -349,27 +350,28 @@ def _make_time_spine_data_set(
apply_group_by = True
for agg_time_dimension_spec in required_time_spine_specs:
column_alias = self._column_association_resolver.resolve_spec(agg_time_dimension_spec).column_name
# If the requested granularity is the same as the granularity of the spine, do a direct select.
agg_time_grain = agg_time_dimension_spec.time_granularity
if (
agg_time_grain.base_granularity == time_spine_source.base_granularity
and not agg_time_grain.is_custom_granularity
):
expr: SqlExpressionNode = base_column_expr
# If there is a date_part selected, apply an EXTRACT() to the base column.
if agg_time_dimension_spec.date_part:
expr: SqlExpressionNode = SqlExtractExpression.create(
date_part=agg_time_dimension_spec.date_part, arg=base_column_expr
)
# If the requested granularity is the same as the granularity of the spine, do a direct select.
elif agg_time_grain == ExpandedTimeGranularity.from_time_granularity(time_spine_source.base_granularity):
expr = base_column_expr
apply_group_by = False
# If the granularity is custom, select the appropriate custom granularity column.
elif agg_time_grain.is_custom_granularity:
# If any dimensions require a custom granularity, select the appropriate column.
for custom_granularity in time_spine_source.custom_granularities:
expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_table_alias, column_name=custom_granularity.parsed_column_name
)
# Otherwise, apply the requested standard granularity using a DATE_TRUNC() on the base column.
else:
# If any dimensions require a different standard granularity, apply a DATE_TRUNC() to the base column.
expr = SqlDateTruncExpression.create(
time_granularity=agg_time_grain.base_granularity, arg=base_column_expr
)
select_columns += (SqlSelectColumn(expr=expr, column_alias=column_alias),)
# TODO: also handle date part.

output_instance_set = InstanceSet(
time_dimension_instances=tuple(
@@ -383,7 +385,7 @@
associated_columns=(self._column_association_resolver.resolve_spec(spec),),
spec=spec,
)
for spec in required_time_spine_specs
for spec in agg_time_dimension_specs
]
)
)
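To make the new branch order in _make_time_spine_data_set concrete, below is a minimal standalone sketch of how each requested spec could map to a select expression: a date_part wins first, then a direct select when the grain matches the spine's base grain, then the pre-computed custom-grain column, and otherwise a DATE_TRUNC. The TimeSpec type and the plain-string SQL building are simplified stand-ins for illustration, not metricflow's spec classes or SqlExpressionNode types, and the apply_group_by bookkeeping is omitted.

from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class TimeSpec:
    """Stand-in for an agg time dimension spec (not the metricflow class)."""

    column_alias: str
    granularity: str                 # e.g. "day", "month", or a custom grain name
    is_custom: bool = False
    date_part: Optional[str] = None  # e.g. "dow", "doy"


def time_spine_select_expr(spec: TimeSpec, base_column: str, spine_base_grain: str) -> str:
    """Mirror the branch order above when picking the select expression for one spec."""
    if spec.date_part:
        # A requested date_part takes precedence: EXTRACT() from the base column.
        return f"EXTRACT({spec.date_part} FROM {base_column})"
    if not spec.is_custom and spec.granularity == spine_base_grain:
        # Grain matches the spine's base grain: select the base column directly.
        return base_column
    if spec.is_custom:
        # Custom grains read their own pre-computed column on the time spine table
        # (here the grain name stands in for that column name).
        return spec.granularity
    # Any other standard grain: truncate the base column.
    return f"DATE_TRUNC('{spec.granularity}', {base_column})"


if __name__ == "__main__":
    base = "time_spine_src.ds"
    for spec in (
        TimeSpec("metric_time__day", "day"),
        TimeSpec("metric_time__month", "month"),
        TimeSpec("metric_time__martian_day", "martian_day", is_custom=True),
        TimeSpec("metric_time__extract_dow", "day", date_part="dow"),
    ):
        print(f"{time_spine_select_expr(spec, base, 'day')} AS {spec.column_alias}")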
@@ -1391,7 +1393,6 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
join_on_time_dimension_sample = included_metric_time_instances[0].spec
else:
join_on_time_dimension_sample = agg_time_dimension_instances[0].spec

agg_time_dimension_instance_for_join = self._choose_instance_for_time_spine_join(
[
instance
@@ -1400,11 +1401,13 @@
and instance.spec.entity_links == join_on_time_dimension_sample.entity_links
]
)
if agg_time_dimension_instance_for_join not in agg_time_dimension_instances:
agg_time_dimension_instances = (agg_time_dimension_instance_for_join,) + agg_time_dimension_instances

# Build time spine data set with just the agg_time_dimension instance needed for the join.
time_spine_alias = self._next_unique_table_alias()
time_spine_dataset = self._make_time_spine_data_set(
agg_time_dimension_instances=(agg_time_dimension_instance_for_join,),
agg_time_dimension_instances=agg_time_dimension_instances,
time_range_constraint=node.time_range_constraint,
time_spine_where_constraints=node.time_spine_filters or (),
)
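A small sketch of the ordering step above: the time spine data set now has to expose every requested agg time dimension plus the column used for the join, without duplicating the join column when it was already requested. The helper name and the use of plain strings are illustrative only, not metricflow code.

from typing import Sequence, Tuple, TypeVar

T = TypeVar("T")


def with_join_instance_first(join_instance: T, requested: Sequence[T]) -> Tuple[T, ...]:
    """Hypothetical helper: prepend the join instance only when it was not already requested."""
    if join_instance in requested:
        return tuple(requested)
    return (join_instance,) + tuple(requested)


# Example: join on metric_time__day while the query only asked for week and month.
print(with_join_instance_first("metric_time__day", ["metric_time__week", "metric_time__month"]))
# -> ('metric_time__day', 'metric_time__week', 'metric_time__month')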
@@ -1420,105 +1423,52 @@
parent_alias=parent_alias,
)

# Select all instances from the parent data set, EXCEPT agg_time_dimensions.
# The agg_time_dimensions will be selected from the time spine data set.
time_dimensions_to_select_from_parent: Tuple[TimeDimensionInstance, ...] = ()
time_dimensions_to_select_from_time_spine: Tuple[TimeDimensionInstance, ...] = ()
for time_dimension_instance in parent_data_set.instance_set.time_dimension_instances:
if time_dimension_instance in agg_time_dimension_instances:
time_dimensions_to_select_from_time_spine += (time_dimension_instance,)
else:
time_dimensions_to_select_from_parent += (time_dimension_instance,)
parent_instance_set = InstanceSet(
measure_instances=parent_data_set.instance_set.measure_instances,
dimension_instances=parent_data_set.instance_set.dimension_instances,
time_dimension_instances=time_dimensions_to_select_from_parent,
entity_instances=parent_data_set.instance_set.entity_instances,
metric_instances=parent_data_set.instance_set.metric_instances,
metadata_instances=parent_data_set.instance_set.metadata_instances,
)
parent_select_columns = create_simple_select_columns_for_instance_sets(
self._column_association_resolver, OrderedDict({parent_alias: parent_instance_set})
)

original_time_spine_dim_instance = time_spine_dataset.instance_for_time_dimension(
agg_time_dimension_instance_for_join.spec
)
time_spine_column_select_expr: Union[
SqlColumnReferenceExpression, SqlDateTruncExpression
] = SqlColumnReferenceExpression.create(
SqlColumnReference(
table_alias=time_spine_alias, column_name=original_time_spine_dim_instance.spec.qualified_name
)
)
# Remove time spine instances from parent instance set.
time_spine_instances = time_spine_dataset.instance_set
time_spine_specs = time_spine_instances.spec_set
parent_instance_set = parent_data_set.instance_set.transform(FilterElements(exclude_specs=time_spine_specs))

time_spine_select_columns = []
time_spine_dim_instances = []
where_filter: Optional[SqlExpressionNode] = None
# Build select columns
select_columns = create_simple_select_columns_for_instance_sets(
self._column_association_resolver,
OrderedDict({parent_alias: parent_instance_set, time_spine_alias: time_spine_dataset.instance_set}),
)

# If offset_to_grain is used, will need to filter down to rows that match selected granularities.
# Does not apply if one of the granularities selected matches the time spine column granularity.
where_filter: Optional[SqlExpressionNode] = None
need_where_filter = (
node.offset_to_grain
and original_time_spine_dim_instance.spec not in node.requested_agg_time_dimension_specs
and agg_time_dimension_instance_for_join.spec not in node.requested_agg_time_dimension_specs
)

# Add requested granularities (if different from time_spine) and date_parts to time spine column.
for parent_time_dimension_instance in time_dimensions_to_select_from_time_spine:
time_dimension_spec = parent_time_dimension_instance.spec
if (
time_dimension_spec.time_granularity.base_granularity.to_int()
< original_time_spine_dim_instance.spec.time_granularity.base_granularity.to_int()
):
raise RuntimeError(
f"Can't join to time spine for a time dimension with a smaller granularity than that of the time "
f"spine column. Got {time_dimension_spec.time_granularity} for time dimension, "
f"{original_time_spine_dim_instance.spec.time_granularity} for time spine."
)

# Apply grain to time spine select expression, unless grain already matches original time spine column.
should_skip_date_trunc = (
time_dimension_spec.time_granularity == original_time_spine_dim_instance.spec.time_granularity
or time_dimension_spec.time_granularity.is_custom_granularity
)
select_expr: SqlExpressionNode = (
time_spine_column_select_expr
if should_skip_date_trunc
else SqlDateTruncExpression.create(
time_granularity=time_dimension_spec.time_granularity.base_granularity,
arg=time_spine_column_select_expr,
)
if need_where_filter:
join_column_expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias,
column_name=agg_time_dimension_instance_for_join.associated_column.column_name,
)
# Filter down to one row per granularity period requested in the group by. Any other granularities
# included here will be filtered out in later nodes so should not be included in where filter.
if need_where_filter and time_dimension_spec in node.requested_agg_time_dimension_specs:
new_where_filter = SqlComparisonExpression.create(
left_expr=select_expr, comparison=SqlComparison.EQUALS, right_expr=time_spine_column_select_expr
)
where_filter = (
SqlLogicalExpression.create(operator=SqlLogicalOperator.OR, args=(where_filter, new_where_filter))
if where_filter
else new_where_filter
)

# Apply date_part to time spine column select expression.
if time_dimension_spec.date_part:
select_expr = SqlExtractExpression.create(date_part=time_dimension_spec.date_part, arg=select_expr)

time_spine_dim_instance = parent_time_dimension_instance.with_new_defined_from(
original_time_spine_dim_instance.defined_from
)
time_spine_dim_instances.append(time_spine_dim_instance)
time_spine_select_columns.append(
SqlSelectColumn(expr=select_expr, column_alias=time_spine_dim_instance.associated_column.column_name)
)
time_spine_instance_set = InstanceSet(time_dimension_instances=tuple(time_spine_dim_instances))
for time_spine_instance in time_spine_instances.as_tuple:
# Filter down to one row per granularity period requested in the group by. Any other granularities
# included here will be filtered out in later nodes so should not be included in where filter.
if need_where_filter and time_spine_instance.spec in node.requested_agg_time_dimension_specs:
column_to_filter_expr = SqlColumnReferenceExpression.from_table_and_column_names(
table_alias=time_spine_alias, column_name=time_spine_instance.associated_column.column_name
)
new_where_filter = SqlComparisonExpression.create(
left_expr=column_to_filter_expr, comparison=SqlComparison.EQUALS, right_expr=join_column_expr
)
where_filter = (
SqlLogicalExpression.create(
operator=SqlLogicalOperator.OR, args=(where_filter, new_where_filter)
)
if where_filter
else new_where_filter
)

return SqlDataSet(
instance_set=InstanceSet.merge([time_spine_instance_set, parent_instance_set]),
instance_set=InstanceSet.merge([time_spine_dataset.instance_set, parent_instance_set]),
sql_select_node=SqlSelectStatementNode.create(
description=node.description,
select_columns=tuple(time_spine_select_columns) + parent_select_columns,
select_columns=select_columns,
from_source=time_spine_dataset.checked_sql_select_node,
from_source_alias=time_spine_alias,
join_descs=(join_description,),
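The where-filter construction above (applied for offset_to_grain) can be summarized with a short sketch: for each granularity requested in the group by, keep only the time spine rows where that column lines up with the column used for the join, OR-ing the comparisons together when several granularities were requested. Plain SQL strings stand in for metricflow's SqlComparisonExpression / SqlLogicalExpression nodes, and the alias and column names are illustrative.

from typing import Optional, Sequence


def offset_to_grain_where_filter(
    time_spine_alias: str,
    join_column: str,
    requested_columns: Sequence[str],
) -> Optional[str]:
    """Build the OR-joined equality filter sketched above (illustrative strings only)."""
    comparisons = [
        f"{time_spine_alias}.{requested} = {time_spine_alias}.{join_column}"
        for requested in requested_columns
    ]
    if not comparisons:
        return None
    return " OR ".join(comparisons)


# Example: a query grouped by month and week, joined to the time spine on metric_time__day.
print(offset_to_grain_where_filter("subq_7", "metric_time__day", ["metric_time__month", "metric_time__week"]))
# -> subq_7.metric_time__month = subq_7.metric_time__day OR subq_7.metric_time__week = subq_7.metric_time__day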
@@ -47,8 +47,8 @@ FROM (
FROM (
-- Join Self Over Time Range
SELECT
subq_2.metric_time__day AS metric_time__day
, subq_2.metric_time__week AS metric_time__week
subq_2.metric_time__week AS metric_time__week
, subq_2.metric_time__day AS metric_time__day
, subq_2.metric_time__quarter AS metric_time__quarter
, subq_1.ds__day AS ds__day
, subq_1.ds__week AS ds__week
@@ -24,8 +24,8 @@ FROM (
FROM (
-- Join Self Over Time Range
SELECT
subq_2.revenue_instance__ds__day AS revenue_instance__ds__day
, subq_2.revenue_instance__ds__month AS revenue_instance__ds__month
subq_2.revenue_instance__ds__month AS revenue_instance__ds__month
, subq_2.revenue_instance__ds__day AS revenue_instance__ds__day
, subq_1.ds__day AS ds__day
, subq_1.ds__week AS ds__week
, subq_1.ds__month AS ds__month
@@ -24,8 +24,8 @@ FROM (
FROM (
-- Join Self Over Time Range
SELECT
subq_2.metric_time__day AS metric_time__day
, subq_2.metric_time__month AS metric_time__month
subq_2.metric_time__month AS metric_time__month
, subq_2.metric_time__day AS metric_time__day
, subq_1.ds__day AS ds__day
, subq_1.ds__week AS ds__week
, subq_1.ds__month AS ds__month
@@ -38,8 +38,8 @@ FROM (
FROM (
-- Join Self Over Time Range
SELECT
subq_2.metric_time__day AS metric_time__day
, subq_2.metric_time__week AS metric_time__week
subq_2.metric_time__week AS metric_time__week
, subq_2.metric_time__day AS metric_time__day
, subq_1.ds__day AS ds__day
, subq_1.ds__week AS ds__week
, subq_1.ds__month AS ds__month
@@ -39,8 +39,8 @@ FROM (
FROM (
-- Join Self Over Time Range
SELECT
subq_2.metric_time__day AS metric_time__day
, subq_2.metric_time__week AS metric_time__week
subq_2.metric_time__week AS metric_time__week
, subq_2.metric_time__day AS metric_time__day
, subq_1.ds__day AS ds__day
, subq_1.ds__week AS ds__week
, subq_1.ds__month AS ds__month
@@ -47,9 +47,9 @@ FROM (
FROM (
-- Join Self Over Time Range
SELECT
subq_2.revenue_instance__ds__quarter AS revenue_instance__ds__quarter
, subq_2.revenue_instance__ds__year AS revenue_instance__ds__year
subq_2.revenue_instance__ds__year AS revenue_instance__ds__year
, subq_2.metric_time__day AS metric_time__day
, subq_2.revenue_instance__ds__quarter AS revenue_instance__ds__quarter
, subq_1.ds__day AS ds__day
, subq_1.ds__week AS ds__week
, subq_1.ds__month AS ds__month
@@ -34,8 +34,8 @@ FROM (
FROM (
-- Join Self Over Time Range
SELECT
subq_2.metric_time__day AS metric_time__day
, subq_2.metric_time__year AS metric_time__year
subq_2.metric_time__year AS metric_time__year
, subq_2.metric_time__day AS metric_time__day
, subq_1.ds__day AS ds__day
, subq_1.ds__week AS ds__week
, subq_1.ds__month AS ds__month
@@ -8,39 +8,41 @@ sql_engine: DuckDB
---
-- Re-aggregate Metric via Group By
SELECT
subq_11.booking__ds__month
, subq_11.metric_time__week
subq_11.metric_time__week
, subq_11.booking__ds__month
, subq_11.every_two_days_bookers_fill_nulls_with_0
FROM (
-- Window Function for Metric Re-aggregation
SELECT
subq_10.booking__ds__month
, subq_10.metric_time__week
subq_10.metric_time__week
, subq_10.booking__ds__month
, FIRST_VALUE(subq_10.every_two_days_bookers_fill_nulls_with_0) OVER (
PARTITION BY
subq_10.booking__ds__month
, subq_10.metric_time__week
subq_10.metric_time__week
, subq_10.booking__ds__month
ORDER BY subq_10.metric_time__day
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS every_two_days_bookers_fill_nulls_with_0
FROM (
-- Compute Metrics via Expressions
SELECT
subq_9.booking__ds__month
subq_9.metric_time__week
, subq_9.booking__ds__month
, subq_9.metric_time__day
, subq_9.metric_time__week
, COALESCE(subq_9.bookers, 0) AS every_two_days_bookers_fill_nulls_with_0
FROM (
-- Join to Time Spine Dataset
SELECT
DATE_TRUNC('month', subq_7.metric_time__day) AS booking__ds__month
subq_7.metric_time__week AS metric_time__week
, subq_7.booking__ds__month AS booking__ds__month
, subq_7.metric_time__day AS metric_time__day
, DATE_TRUNC('week', subq_7.metric_time__day) AS metric_time__week
, subq_6.bookers AS bookers
FROM (
-- Time Spine
SELECT
subq_8.ds AS metric_time__day
DATE_TRUNC('month', subq_8.ds) AS booking__ds__month
, subq_8.ds AS metric_time__day
, DATE_TRUNC('week', subq_8.ds) AS metric_time__week
FROM ***************************.mf_time_spine subq_8
) subq_7
LEFT OUTER JOIN (
Expand All @@ -60,9 +62,9 @@ FROM (
FROM (
-- Join Self Over Time Range
SELECT
subq_2.booking__ds__month AS booking__ds__month
subq_2.metric_time__week AS metric_time__week
, subq_2.booking__ds__month AS booking__ds__month
, subq_2.metric_time__day AS metric_time__day
, subq_2.metric_time__week AS metric_time__week
, subq_1.ds__day AS ds__day
, subq_1.ds__week AS ds__week
, subq_1.ds__month AS ds__month
@@ -380,6 +382,6 @@
) subq_10
) subq_11
GROUP BY
subq_11.booking__ds__month
, subq_11.metric_time__week
subq_11.metric_time__week
, subq_11.booking__ds__month
, subq_11.every_two_days_bookers_fill_nulls_with_0