|
10 | 10 | SqlColumnReferenceExpression, |
11 | 11 | SqlComparison, |
12 | 12 | SqlComparisonExpression, |
| 13 | + SqlStringExpression, |
13 | 14 | ) |
14 | 15 | from metricflow_semantics.sql.sql_join_type import SqlJoinType |
15 | 16 | from metricflow_semantics.sql.sql_table import SqlTable |
@@ -464,3 +465,111 @@ def test_common_cte_aliases_in_nested_query( |
464 | 465 | """ |
465 | 466 | ), |
466 | 467 | ) |
| 468 | + |
| 469 | + |
| 470 | +def test_string_expression( |
| 471 | + request: FixtureRequest, |
| 472 | + mf_test_configuration: MetricFlowTestConfiguration, |
| 473 | + column_pruner: SqlColumnPrunerOptimizer, |
| 474 | + sql_plan_renderer: DefaultSqlPlanRenderer, |
| 475 | +) -> None: |
| 476 | + """Test a string expression that references a column in the cte.""" |
| 477 | + select_statement = SqlSelectStatementNode.create( |
| 478 | + description="Top-level SELECT", |
| 479 | + select_columns=( |
| 480 | + SqlSelectColumn( |
| 481 | + expr=SqlStringExpression.create(sql_expr="cte_source_0__col_0", used_columns=("cte_source_0__col_0",)), |
| 482 | + column_alias="top_level__col_0", |
| 483 | + ), |
| 484 | + ), |
| 485 | + from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")), |
| 486 | + from_source_alias="cte_source_0_alias", |
| 487 | + cte_sources=( |
| 488 | + SqlCteNode.create( |
| 489 | + cte_alias="cte_source_0", |
| 490 | + select_statement=SqlSelectStatementNode.create( |
| 491 | + description="CTE source 0", |
| 492 | + select_columns=( |
| 493 | + SqlSelectColumn( |
| 494 | + expr=SqlColumnReferenceExpression.create( |
| 495 | + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0") |
| 496 | + ), |
| 497 | + column_alias="cte_source_0__col_0", |
| 498 | + ), |
| 499 | + SqlSelectColumn( |
| 500 | + expr=SqlColumnReferenceExpression.create( |
| 501 | + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0") |
| 502 | + ), |
| 503 | + column_alias="cte_source_0__col_1", |
| 504 | + ), |
| 505 | + ), |
| 506 | + from_source=SqlTableNode.create( |
| 507 | + sql_table=SqlTable(schema_name="test_schema", table_name="test_table") |
| 508 | + ), |
| 509 | + from_source_alias="test_table_alias", |
| 510 | + ), |
| 511 | + ), |
| 512 | + ), |
| 513 | + ) |
| 514 | + assert_optimizer_result_snapshot_equal( |
| 515 | + request=request, |
| 516 | + mf_test_configuration=mf_test_configuration, |
| 517 | + optimizer=column_pruner, |
| 518 | + sql_plan_renderer=sql_plan_renderer, |
| 519 | + select_statement=select_statement, |
| 520 | + expectation_description="`cte_source_0__col_01` should be retained in the CTE.", |
| 521 | + ) |
| 522 | + |
| 523 | + |
| 524 | +def test_column_reference_expression( |
| 525 | + request: FixtureRequest, |
| 526 | + mf_test_configuration: MetricFlowTestConfiguration, |
| 527 | + column_pruner: SqlColumnPrunerOptimizer, |
| 528 | + sql_plan_renderer: DefaultSqlPlanRenderer, |
| 529 | +) -> None: |
| 530 | + """Test a column reference expression that does not specify a table alias.""" |
| 531 | + select_statement = SqlSelectStatementNode.create( |
| 532 | + description="Top-level SELECT", |
| 533 | + select_columns=( |
| 534 | + SqlSelectColumn( |
| 535 | + expr=SqlStringExpression.create(sql_expr="cte_source_0__col_0", used_columns=("cte_source_0__col_0",)), |
| 536 | + column_alias="top_level__col_0", |
| 537 | + ), |
| 538 | + ), |
| 539 | + from_source=SqlTableNode.create(sql_table=SqlTable(schema_name=None, table_name="cte_source_0")), |
| 540 | + from_source_alias="cte_source_0_alias", |
| 541 | + cte_sources=( |
| 542 | + SqlCteNode.create( |
| 543 | + cte_alias="cte_source_0", |
| 544 | + select_statement=SqlSelectStatementNode.create( |
| 545 | + description="CTE source 0", |
| 546 | + select_columns=( |
| 547 | + SqlSelectColumn( |
| 548 | + expr=SqlColumnReferenceExpression.create( |
| 549 | + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0") |
| 550 | + ), |
| 551 | + column_alias="cte_source_0__col_0", |
| 552 | + ), |
| 553 | + SqlSelectColumn( |
| 554 | + expr=SqlColumnReferenceExpression.create( |
| 555 | + col_ref=SqlColumnReference(table_alias="test_table_alias", column_name="col_0") |
| 556 | + ), |
| 557 | + column_alias="cte_source_0__col_1", |
| 558 | + ), |
| 559 | + ), |
| 560 | + from_source=SqlTableNode.create( |
| 561 | + sql_table=SqlTable(schema_name="test_schema", table_name="test_table") |
| 562 | + ), |
| 563 | + from_source_alias="test_table_alias", |
| 564 | + ), |
| 565 | + ), |
| 566 | + ), |
| 567 | + ) |
| 568 | + assert_optimizer_result_snapshot_equal( |
| 569 | + request=request, |
| 570 | + mf_test_configuration=mf_test_configuration, |
| 571 | + optimizer=column_pruner, |
| 572 | + sql_plan_renderer=sql_plan_renderer, |
| 573 | + select_statement=select_statement, |
| 574 | + expectation_description="`cte_source_0__col_01` should be retained in the CTE.", |
| 575 | + ) |
0 commit comments