From 40444fe57e0baa54c536ece4c1506e781334955b Mon Sep 17 00:00:00 2001 From: Mayuri N Date: Thu, 6 Feb 2025 16:32:49 +0530 Subject: [PATCH] feat(ingest/snowflake): apply table name normalization for queries_v2 Adds config `table_name_normalization_rules` to snowflake. The tables identified by these rules should typically be temporary or transient. This helps in uniform query_id generation for queries using random table names for each ETL run for tools like DBT, Segment, Fivetran --- .../source/snowflake/snowflake_config.py | 22 ++++++++ .../source/snowflake/snowflake_lineage_v2.py | 4 +- .../source/snowflake/snowflake_queries.py | 30 ++++++++++- .../source/snowflake/snowflake_v2.py | 1 + .../sql_parsing/sql_parsing_aggregator.py | 3 ++ .../datahub/sql_parsing/sqlglot_lineage.py | 1 + .../src/datahub/sql_parsing/sqlglot_utils.py | 50 ++++++++++++++----- .../unit/sql_parsing/test_sqlglot_utils.py | 7 ++- 8 files changed, 101 insertions(+), 17 deletions(-) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py index 94ce51c031b8a0..2a21553c78a5ec 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -1,4 +1,5 @@ import logging +import re from collections import defaultdict from dataclasses import dataclass from typing import Dict, List, Optional, Set @@ -300,6 +301,18 @@ class SnowflakeV2Config( "to ignore the temporary staging tables created by known ETL tools.", ) + table_name_normalization_rules: Dict[re.Pattern, str] = pydantic.Field( + default={}, + description="[Advanced] Regex patterns for table names to normalize in lineage ingestion. " + "Specify key as regex to match the table name as it appears in query. " + "The value is the normalized table name. " + "Defaults are to set in such a way to normalize the staging tables created by known ETL tools." + "The tables identified by these rules should typically be temporary or transient tables " + "and should not be used directly in other tools. DataHub will not be able to detect cross" + "-platform lineage for such tables.", + # "Only applicable if `use_queries_v2` is enabled.", + ) + rename_upstreams_deny_pattern_to_temporary_table_pattern = pydantic_renamed_field( "upstreams_deny_pattern", "temporary_tables_pattern" ) @@ -325,6 +338,15 @@ class SnowflakeV2Config( "Only applicable if `use_queries_v2` is enabled.", ) + @validator("table_name_normalization_rules") + def validate_pattern(cls, pattern): + if isinstance(pattern, re.Pattern): # Already compiled, good + return pattern + try: + return re.compile(pattern) # Attempt compilation + except re.error as e: + raise ValueError(f"Invalid regular expression: {e}") + @validator("convert_urns_to_lowercase") def validate_convert_urns_to_lowercase(cls, v): if not v: diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py index e93ecf30171f65..c0b256035e625e 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py @@ -241,7 +241,9 @@ def get_known_query_lineage( known_lineage = KnownQueryLineageInfo( query_id=get_query_fingerprint( - query.query_text, self.identifiers.platform, fast=True + query.query_text, + self.identifiers.platform, + fast=True, ), query_text=query.query_text, downstream=downstream_table_urn, diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py index eb015f9d13281f..ac849616129491 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py @@ -96,6 +96,17 @@ class SnowflakeQueriesExtractorConfig(ConfigModel): "to ignore the temporary staging tables created by known ETL tools.", ) + table_name_normalization_rules: Dict[re.Pattern, str] = pydantic.Field( + default={}, + description="[Advanced] Regex patterns for table names to normalize in lineage ingestion. " + "Specify key as regex to match the table name as it appears in query. " + "The value is the normalized table name. " + "Defaults are to set in such a way to normalize the staging tables created by known ETL tools." + "The tables identified by these rules should typically be temporary or transient tables " + "and should not be used directly in other tools. DataHub will not be able to detect cross" + "-platform lineage for such tables.", + ) + local_temp_path: Optional[pathlib.Path] = pydantic.Field( default=None, description="Local path to store the audit log.", @@ -110,6 +121,15 @@ class SnowflakeQueriesExtractorConfig(ConfigModel): include_query_usage_statistics: bool = True include_operations: bool = True + @pydantic.validator("table_name_normalization_rules") + def validate_pattern(cls, pattern): + if isinstance(pattern, re.Pattern): # Already compiled, good + return pattern + try: + return re.compile(pattern) # Attempt compilation + except re.error as e: + raise ValueError(f"Invalid regular expression: {e}") + class SnowflakeQueriesSourceConfig( SnowflakeQueriesExtractorConfig, SnowflakeIdentifierConfig, SnowflakeFilterConfig @@ -435,7 +455,10 @@ def _parse_audit_log_row( default_db=res["default_db"], default_schema=res["default_schema"], query_hash=get_query_fingerprint( - res["query_text"], self.identifiers.platform, fast=True + res["query_text"], + self.identifiers.platform, + fast=True, + table_name_normalization_rules=self.config.table_name_normalization_rules, ), ) @@ -514,7 +537,10 @@ def _parse_audit_log_row( # job at eliminating redundant / repetitive queries. As such, we include the fast fingerprint # here query_id=get_query_fingerprint( - res["query_text"], self.identifiers.platform, fast=True + res["query_text"], + self.identifiers.platform, + fast=True, + table_name_normalization_rules=self.config.table_name_normalization_rules, ), query_text=res["query_text"], upstreams=upstreams, diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index daeb839e5f54da..9e8be5ade99bf8 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -573,6 +573,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: config=SnowflakeQueriesExtractorConfig( window=self.config, temporary_tables_pattern=self.config.temporary_tables_pattern, + table_name_normalization_rules=self.config.table_name_normalization_rules, include_lineage=self.config.include_table_lineage, include_usage_statistics=self.config.include_usage_stats, include_operations=self.config.include_operational_stats, diff --git a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py index 893e89f177094f..b89420d7e73c58 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py @@ -337,6 +337,7 @@ def __init__( is_allowed_table: Optional[Callable[[str], bool]] = None, format_queries: bool = True, query_log: QueryLogSetting = _DEFAULT_QUERY_LOG_SETTING, + table_name_normalization_rules: Dict[str, str] = {}, ) -> None: self.platform = DataPlatformUrn(platform) self.platform_instance = platform_instance @@ -495,6 +496,8 @@ def __init__( self._tool_meta_extractor = ToolMetaExtractor.create(graph) self.report.tool_meta_report = self._tool_meta_extractor.report + self.table_name_normalization_rules = table_name_normalization_rules + def close(self) -> None: # Compute stats once before closing connections self.report.compute_stats() diff --git a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py index c825deeccd9592..7b9820c20a635c 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py @@ -1085,6 +1085,7 @@ def _sqlglot_lineage_inner( query_type, query_type_props = get_query_type_of_sql( original_statement, dialect=dialect ) + # TODO: support table name normalization rules for non-fast fingerprinting query_fingerprint, debug_info.generalized_statement = get_query_fingerprint_debug( original_statement, dialect ) diff --git a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py index 5b12c64a831666..d81dc31e963db9 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py @@ -157,24 +157,26 @@ def _expression_to_string( r"\b(ge_tmp_|ge_temp_|gx_temp_)[0-9a-f]{8}\b", re.IGNORECASE ): r"\1abcdefgh", # Date-suffixed table names (e.g. _20210101) - re.compile(r"\b(\w+)(19|20)\d{4}\b"): r"\1YYYYMM", - re.compile(r"\b(\w+)(19|20)\d{6}\b"): r"\1YYYYMMDD", - re.compile(r"\b(\w+)(19|20)\d{8}\b"): r"\1YYYYMMDDHH", - re.compile(r"\b(\w+)(19|20)\d{10}\b"): r"\1YYYYMMDDHHMM", + re.compile(r"\b(\w+_?)(19|20)\d{4}\b"): r"\1YYYYMM", + re.compile(r"\b(\w+_?)(19|20)\d{6}\b"): r"\1YYYYMMDD", + re.compile(r"\b(\w+_?)(19|20)\d{8}\b"): r"\1YYYYMMDDHH", + re.compile(r"\b(\w+_?)(19|20)\d{10}\b"): r"\1YYYYMMDDHHMM", + re.compile(r"\b(\w+_?)(19|20)\d{12}\b"): r"\1YYYYMMDDHHMMSS", + re.compile(r"\b(\w+_?)(19|20)\d{18}\b"): r"\1YYYYMMDDHHMMSSffffff", } def generalize_query_fast( expression: sqlglot.exp.ExpOrStr, dialect: DialectOrStr, - change_table_names: bool = False, + table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None, ) -> str: """Variant of `generalize_query` that only does basic normalization. Args: expression: The SQL query to generalize. dialect: The SQL dialect to use. - change_table_names: If True, replace table names with placeholders. Note + table_name_normalization_rules: If Set, replace table names with placeholders. Note that this should only be used for query filtering purposes, as it violates the general assumption that the queries with the same fingerprint have the same lineage/usage/etc. @@ -189,7 +191,7 @@ def generalize_query_fast( REGEX_REPLACEMENTS = { **_BASIC_NORMALIZATION_RULES, - **(_TABLE_NAME_NORMALIZATION_RULES if change_table_names else {}), + **(table_name_normalization_rules if table_name_normalization_rules else {}), } for pattern, replacement in REGEX_REPLACEMENTS.items(): @@ -197,7 +199,11 @@ def generalize_query_fast( return query_text -def generalize_query(expression: sqlglot.exp.ExpOrStr, dialect: DialectOrStr) -> str: +def generalize_query( + expression: sqlglot.exp.ExpOrStr, + dialect: DialectOrStr, + table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None, +) -> str: """ Generalize/normalize a SQL query. @@ -222,6 +228,7 @@ def generalize_query(expression: sqlglot.exp.ExpOrStr, dialect: DialectOrStr) -> # https://tobikodata.com/are_these_sql_queries_the_same.html # which is used to determine if queries are functionally equivalent. + # TODO: apply table name normalization rules here dialect = get_dialect(dialect) expression = sqlglot.maybe_parse(expression, dialect=dialect) @@ -260,14 +267,23 @@ def generate_hash(text: str) -> str: def get_query_fingerprint_debug( - expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr, fast: bool = False + expression: sqlglot.exp.ExpOrStr, + platform: DialectOrStr, + table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None, + fast: bool = False, ) -> Tuple[str, Optional[str]]: try: if not fast: dialect = get_dialect(platform) - expression_sql = generalize_query(expression, dialect=dialect) + expression_sql = generalize_query( + expression, dialect, table_name_normalization_rules + ) else: - expression_sql = generalize_query_fast(expression, dialect=platform) + expression_sql = generalize_query_fast( + expression, + dialect=platform, + table_name_normalization_rules=table_name_normalization_rules, + ) except (ValueError, sqlglot.errors.SqlglotError) as e: if not isinstance(expression, str): raise @@ -284,7 +300,10 @@ def get_query_fingerprint_debug( def get_query_fingerprint( - expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr, fast: bool = False + expression: sqlglot.exp.ExpOrStr, + platform: DialectOrStr, + fast: bool = False, + table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None, ) -> str: """Get a fingerprint for a SQL query. @@ -306,7 +325,12 @@ def get_query_fingerprint( The fingerprint for the SQL query. """ - return get_query_fingerprint_debug(expression, platform, fast=fast)[0] + return get_query_fingerprint_debug( + expression, + platform, + fast=fast, + table_name_normalization_rules=(table_name_normalization_rules or {}), + )[0] @functools.lru_cache(maxsize=FORMAT_QUERY_CACHE_SIZE) diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_utils.py b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_utils.py index c3c3a4a15d915b..3070bd095c8f26 100644 --- a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_utils.py +++ b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_utils.py @@ -8,6 +8,7 @@ from datahub.sql_parsing.sql_parsing_common import QueryType from datahub.sql_parsing.sqlglot_lineage import _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT from datahub.sql_parsing.sqlglot_utils import ( + _TABLE_NAME_NORMALIZATION_RULES, generalize_query, generalize_query_fast, get_dialect, @@ -173,7 +174,11 @@ def test_query_generalization( assert generalize_query(query, dialect=dialect) == expected if mode in {QueryGeneralizationTestMode.FAST, QueryGeneralizationTestMode.BOTH}: assert ( - generalize_query_fast(query, dialect=dialect, change_table_names=True) + generalize_query_fast( + query, + dialect=dialect, + table_name_normalization_rules=_TABLE_NAME_NORMALIZATION_RULES, + ) == expected )