Skip to content

Commit

Permalink
Simplify new logic with helpers - should combo with generate new select columns commit
Browse files Browse the repository at this point in the history
  • Loading branch information
courtneyholcomb committed Nov 2, 2024
1 parent 79b66fe commit 47c3730
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 78 deletions.
89 changes: 16 additions & 73 deletions metricflow/plan_conversion/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@
from metricflow_semantics.dag.sequential_id import SequentialIdGenerator
from metricflow_semantics.filters.time_constraint import TimeRangeConstraint
from metricflow_semantics.instances import (
DimensionInstance,
EntityInstance,
GroupByMetricInstance,
InstanceSet,
MdoInstance,
MetadataInstance,
MetricInstance,
TimeDimensionInstance,
group_instances_by_type,
)
from metricflow_semantics.mf_logging.formatting import indent
from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat
Expand Down Expand Up @@ -496,7 +495,7 @@ def build_select_column(
{from_data_set_alias: from_data_set_output_instance_set}
)

# Build join description, instance set, and select columns for each join target.
# Build SQL join description, instance set, and select columns for each join target.
output_instance_set = from_data_set_output_instance_set
select_columns: Tuple[SqlSelectColumn, ...] = ()
sql_join_descs: List[SqlJoinDescription] = []
Expand All @@ -515,7 +514,6 @@ def build_select_column(
sql_join_descs.append(sql_join_desc)

if join_on_entity:
# Build instance set that will be available after join.
# Remove the linkable instances with the join_on_entity as the leading link as the next step adds the
# link. This is to avoid cases where there is a primary entity and a dimension in the data set, and we
# create an instance in the next step that has the same entity link.
Expand All @@ -530,75 +528,20 @@ def build_select_column(
# After the right data set is joined, we need to change the links to indicate that a join was used to
# satisfy them. For example, if the right dataset contains the "country" dimension, and "user_id" is the
# join_on_entity, then the joined data set should have the "user__country" dimension.
transformed_spec: LinkableInstanceSpec
original_instance: MdoInstance
new_instance: MdoInstance
# Soooo much boilerplate. Figure out how to dedupe. - add this logic to the instances
entity_instances: Tuple[EntityInstance, ...] = ()
for original_instance in right_instance_set_filtered.entity_instances:
new_instances: Tuple[MdoInstance, ...] = ()
for original_instance in right_instance_set_filtered.as_tuple:
# Is this necessary? Does it even work? i.e. diff types here
if original_instance.spec == join_on_entity:
continue
transformed_spec = original_instance.spec.with_entity_prefix(join_on_entity.reference)
new_instance = EntityInstance(
associated_columns=build_columns(transformed_spec),
defined_from=original_instance.defined_from,
spec=transformed_spec,
new_instance = original_instance.with_entity_prefix(
join_on_entity.reference, column_association_resolver=self._column_association_resolver
)
entity_instances += (new_instance,)
select_column = build_select_column(
table_alias=right_data_set_alias, original_instance=original_instance, new_instance=new_instance
)
new_instances += (new_instance,)
select_columns += (select_column,)

dimension_instances: Tuple[DimensionInstance, ...] = ()
for original_instance in right_instance_set_filtered.dimension_instances:
transformed_spec = original_instance.spec.with_entity_prefix(join_on_entity.reference)
new_instance = DimensionInstance(
associated_columns=build_columns(transformed_spec),
defined_from=original_instance.defined_from,
spec=transformed_spec,
)
dimension_instances += (new_instance,)
select_column = build_select_column(
table_alias=right_data_set_alias, original_instance=original_instance, new_instance=new_instance
)
select_columns += (select_column,)

time_dimension_instances: Tuple[TimeDimensionInstance, ...] = ()
for original_instance in right_instance_set_filtered.time_dimension_instances:
transformed_spec = original_instance.spec.with_entity_prefix(join_on_entity.reference)
new_instance = TimeDimensionInstance(
associated_columns=build_columns(transformed_spec),
defined_from=original_instance.defined_from,
spec=transformed_spec,
)
time_dimension_instances += (new_instance,)
select_column = build_select_column(
table_alias=right_data_set_alias, original_instance=original_instance, new_instance=new_instance
)
select_columns += (select_column,)

group_by_metric_instances: Tuple[GroupByMetricInstance, ...] = ()
for original_instance in right_instance_set_filtered.group_by_metric_instances:
transformed_spec = original_instance.spec.with_entity_prefix(join_on_entity.reference)
new_instance = GroupByMetricInstance(
associated_columns=build_columns(transformed_spec),
defined_from=original_instance.defined_from,
spec=transformed_spec,
)
group_by_metric_instances += (new_instance,)
select_column = build_select_column(
table_alias=right_data_set_alias, original_instance=original_instance, new_instance=new_instance
)
select_columns += (select_column,)

right_instance_set_after_join = InstanceSet(
dimension_instances=dimension_instances,
entity_instances=entity_instances,
time_dimension_instances=time_dimension_instances,
group_by_metric_instances=group_by_metric_instances,
)
right_instance_set_after_join = group_instances_by_type(new_instances)
else:
right_instance_set_after_join = right_data_set.instance_set
instances_to_build_simple_select_columns_for[right_data_set_alias] = right_instance_set_after_join
Expand Down Expand Up @@ -1237,9 +1180,9 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr
spec=metric_time_dimension_spec,
)
)
output_column_to_input_column[
metric_time_dimension_column_association.column_name
] = matching_time_dimension_instance.associated_column.column_name
output_column_to_input_column[metric_time_dimension_column_association.column_name] = (
matching_time_dimension_instance.associated_column.column_name
)

output_instance_set = InstanceSet(
measure_instances=tuple(output_measure_instances),
Expand Down Expand Up @@ -1476,11 +1419,11 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet
f"indicates it may have been configured incorrectly. Expected: {agg_time_dim_for_join_with_base_grain};"
f" Got: {[instance.spec for instance in time_spine_dataset.instance_set.time_dimension_instances]}"
)
time_spine_column_select_expr: Union[
SqlColumnReferenceExpression, SqlDateTruncExpression
] = SqlColumnReferenceExpression.create(
SqlColumnReference(
table_alias=time_spine_alias, column_name=original_time_spine_dim_instance.spec.qualified_name
time_spine_column_select_expr: Union[SqlColumnReferenceExpression, SqlDateTruncExpression] = (
SqlColumnReferenceExpression.create(
SqlColumnReference(
table_alias=time_spine_alias, column_name=original_time_spine_dim_instance.spec.qualified_name
)
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ FROM (
FROM (
-- Join Standard Outputs
SELECT
subq_12.customer_id__customer_third_hop_id AS account_id__customer_id__customer_third_hop_id
, subq_12.ds_partitioned__day AS account_id__ds_partitioned__day
subq_12.ds_partitioned__day AS account_id__ds_partitioned__day
, subq_12.customer_id__customer_third_hop_id AS account_id__customer_id__customer_third_hop_id
, subq_5.ds_partitioned__day AS ds_partitioned__day
, subq_5.account_id AS account_id
, subq_5.txn_count AS txn_count
Expand Down Expand Up @@ -264,9 +264,7 @@ FROM (
FROM (
-- Join Standard Outputs
SELECT
subq_10.customer_third_hop_id AS customer_id__customer_third_hop_id
, subq_10.customer_third_hop_id__customer_id AS customer_id__customer_third_hop_id__customer_id
, subq_10.country AS customer_id__country
subq_10.country AS customer_id__country
, subq_10.customer_third_hop_id__country AS customer_id__customer_third_hop_id__country
, subq_10.acquired_ds__day AS customer_id__acquired_ds__day
, subq_10.acquired_ds__week AS customer_id__acquired_ds__week
Expand Down Expand Up @@ -301,6 +299,8 @@ FROM (
, subq_10.metric_time__extract_day AS customer_id__metric_time__extract_day
, subq_10.metric_time__extract_dow AS customer_id__metric_time__extract_dow
, subq_10.metric_time__extract_doy AS customer_id__metric_time__extract_doy
, subq_10.customer_third_hop_id AS customer_id__customer_third_hop_id
, subq_10.customer_third_hop_id__customer_id AS customer_id__customer_third_hop_id__customer_id
, subq_7.ds_partitioned__day AS ds_partitioned__day
, subq_7.ds_partitioned__week AS ds_partitioned__week
, subq_7.ds_partitioned__month AS ds_partitioned__month
Expand Down

0 comments on commit 47c3730

Please sign in to comment.