diff --git a/metricflow/plan_conversion/dataflow_to_sql.py b/metricflow/plan_conversion/dataflow_to_sql.py index 7634fb4b46..f6377d88f8 100644 --- a/metricflow/plan_conversion/dataflow_to_sql.py +++ b/metricflow/plan_conversion/dataflow_to_sql.py @@ -19,14 +19,13 @@ from metricflow_semantics.dag.sequential_id import SequentialIdGenerator from metricflow_semantics.filters.time_constraint import TimeRangeConstraint from metricflow_semantics.instances import ( - DimensionInstance, - EntityInstance, GroupByMetricInstance, InstanceSet, MdoInstance, MetadataInstance, MetricInstance, TimeDimensionInstance, + group_instances_by_type, ) from metricflow_semantics.mf_logging.formatting import indent from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat @@ -496,7 +495,7 @@ def build_select_column( {from_data_set_alias: from_data_set_output_instance_set} ) - # Build join description, instance set, and select columns for each join target. + # Build SQL join description, instance set, and select columns for each join target. output_instance_set = from_data_set_output_instance_set select_columns: Tuple[SqlSelectColumn, ...] = () sql_join_descs: List[SqlJoinDescription] = [] @@ -515,7 +514,6 @@ def build_select_column( sql_join_descs.append(sql_join_desc) if join_on_entity: - # Build instance set that will be available after join. # Remove the linkable instances with the join_on_entity as the leading link as the next step adds the # link. This is to avoid cases where there is a primary entity and a dimension in the data set, and we # create an instance in the next step that has the same entity link. @@ -530,75 +528,20 @@ def build_select_column( # After the right data set is joined, we need to change the links to indicate that they a join was used to # satisfy them. For example, if the right dataset contains the "country" dimension, and "user_id" is the # join_on_entity, then the joined data set should have the "user__country" dimension. - transformed_spec: LinkableInstanceSpec - original_instance: MdoInstance - new_instance: MdoInstance - # Soooo much boilerplate. Figure out how to dedupe. - add this logic to the instances - entity_instances: Tuple[EntityInstance, ...] = () - for original_instance in right_instance_set_filtered.entity_instances: + new_instances: Tuple[MdoInstance, ...] = () + for original_instance in right_instance_set_filtered.as_tuple: # Is this necessary? Does it even work? i.e. diff types here if original_instance.spec == join_on_entity: continue - transformed_spec = original_instance.spec.with_entity_prefix(join_on_entity.reference) - new_instance = EntityInstance( - associated_columns=build_columns(transformed_spec), - defined_from=original_instance.defined_from, - spec=transformed_spec, + new_instance = original_instance.with_entity_prefix( + join_on_entity.reference, column_association_resolver=self._column_association_resolver ) - entity_instances += (new_instance,) select_column = build_select_column( table_alias=right_data_set_alias, original_instance=original_instance, new_instance=new_instance ) + new_instances += (new_instance,) select_columns += (select_column,) - - dimension_instances: Tuple[DimensionInstance, ...] = () - for original_instance in right_instance_set_filtered.dimension_instances: - transformed_spec = original_instance.spec.with_entity_prefix(join_on_entity.reference) - new_instance = DimensionInstance( - associated_columns=build_columns(transformed_spec), - defined_from=original_instance.defined_from, - spec=transformed_spec, - ) - dimension_instances += (new_instance,) - select_column = build_select_column( - table_alias=right_data_set_alias, original_instance=original_instance, new_instance=new_instance - ) - select_columns += (select_column,) - - time_dimension_instances: Tuple[TimeDimensionInstance, ...] = () - for original_instance in right_instance_set_filtered.time_dimension_instances: - transformed_spec = original_instance.spec.with_entity_prefix(join_on_entity.reference) - new_instance = TimeDimensionInstance( - associated_columns=build_columns(transformed_spec), - defined_from=original_instance.defined_from, - spec=transformed_spec, - ) - time_dimension_instances += (new_instance,) - select_column = build_select_column( - table_alias=right_data_set_alias, original_instance=original_instance, new_instance=new_instance - ) - select_columns += (select_column,) - - group_by_metric_instances: Tuple[GroupByMetricInstance, ...] = () - for original_instance in right_instance_set_filtered.group_by_metric_instances: - transformed_spec = original_instance.spec.with_entity_prefix(join_on_entity.reference) - new_instance = GroupByMetricInstance( - associated_columns=build_columns(transformed_spec), - defined_from=original_instance.defined_from, - spec=transformed_spec, - ) - group_by_metric_instances += (new_instance,) - select_column = build_select_column( - table_alias=right_data_set_alias, original_instance=original_instance, new_instance=new_instance - ) - select_columns += (select_column,) - - right_instance_set_after_join = InstanceSet( - dimension_instances=dimension_instances, - entity_instances=entity_instances, - time_dimension_instances=time_dimension_instances, - group_by_metric_instances=group_by_metric_instances, - ) + right_instance_set_after_join = group_instances_by_type(new_instances) else: right_instance_set_after_join = right_data_set.instance_set instances_to_build_simple_select_columns_for[right_data_set_alias] = right_instance_set_after_join @@ -1237,9 +1180,9 @@ def visit_metric_time_dimension_transform_node(self, node: MetricTimeDimensionTr spec=metric_time_dimension_spec, ) ) - output_column_to_input_column[ - metric_time_dimension_column_association.column_name - ] = matching_time_dimension_instance.associated_column.column_name + output_column_to_input_column[metric_time_dimension_column_association.column_name] = ( + matching_time_dimension_instance.associated_column.column_name + ) output_instance_set = InstanceSet( measure_instances=tuple(output_measure_instances), @@ -1476,11 +1419,11 @@ def visit_join_to_time_spine_node(self, node: JoinToTimeSpineNode) -> SqlDataSet f"indicates it may have been configured incorrectly. Expected: {agg_time_dim_for_join_with_base_grain};" f" Got: {[instance.spec for instance in time_spine_dataset.instance_set.time_dimension_instances]}" ) - time_spine_column_select_expr: Union[ - SqlColumnReferenceExpression, SqlDateTruncExpression - ] = SqlColumnReferenceExpression.create( - SqlColumnReference( - table_alias=time_spine_alias, column_name=original_time_spine_dim_instance.spec.qualified_name + time_spine_column_select_expr: Union[SqlColumnReferenceExpression, SqlDateTruncExpression] = ( + SqlColumnReferenceExpression.create( + SqlColumnReference( + table_alias=time_spine_alias, column_name=original_time_spine_dim_instance.spec.qualified_name + ) ) ) diff --git a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/DuckDB/test_inner_query_multi_hop__plan0.sql b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/DuckDB/test_inner_query_multi_hop__plan0.sql index 717e7dc67d..5ce70b6587 100644 --- a/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/DuckDB/test_inner_query_multi_hop__plan0.sql +++ b/tests_metricflow/snapshots/test_metric_filter_rendering.py/SqlQueryPlan/DuckDB/test_inner_query_multi_hop__plan0.sql @@ -127,8 +127,8 @@ FROM ( FROM ( -- Join Standard Outputs SELECT - subq_12.customer_id__customer_third_hop_id AS account_id__customer_id__customer_third_hop_id - , subq_12.ds_partitioned__day AS account_id__ds_partitioned__day + subq_12.ds_partitioned__day AS account_id__ds_partitioned__day + , subq_12.customer_id__customer_third_hop_id AS account_id__customer_id__customer_third_hop_id , subq_5.ds_partitioned__day AS ds_partitioned__day , subq_5.account_id AS account_id , subq_5.txn_count AS txn_count @@ -264,9 +264,7 @@ FROM ( FROM ( -- Join Standard Outputs SELECT - subq_10.customer_third_hop_id AS customer_id__customer_third_hop_id - , subq_10.customer_third_hop_id__customer_id AS customer_id__customer_third_hop_id__customer_id - , subq_10.country AS customer_id__country + subq_10.country AS customer_id__country , subq_10.customer_third_hop_id__country AS customer_id__customer_third_hop_id__country , subq_10.acquired_ds__day AS customer_id__acquired_ds__day , subq_10.acquired_ds__week AS customer_id__acquired_ds__week @@ -301,6 +299,8 @@ FROM ( , subq_10.metric_time__extract_day AS customer_id__metric_time__extract_day , subq_10.metric_time__extract_dow AS customer_id__metric_time__extract_dow , subq_10.metric_time__extract_doy AS customer_id__metric_time__extract_doy + , subq_10.customer_third_hop_id AS customer_id__customer_third_hop_id + , subq_10.customer_third_hop_id__customer_id AS customer_id__customer_third_hop_id__customer_id , subq_7.ds_partitioned__day AS ds_partitioned__day , subq_7.ds_partitioned__week AS ds_partitioned__week , subq_7.ds_partitioned__month AS ds_partitioned__month