diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
index 94ce51c031b8a..2a21553c78a5e 100644
--- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
+++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
@@ -1,4 +1,5 @@
 import logging
+import re
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Set
@@ -300,6 +301,18 @@ class SnowflakeV2Config(
         "to ignore the temporary staging tables created by known ETL tools.",
     )
 
+    table_name_normalization_rules: Dict[re.Pattern, str] = pydantic.Field(
+        default={},
+        description="[Advanced] Regex patterns used to normalize table names in lineage ingestion. "
+        "Each key is a regex that matches a table name as it appears in a query; "
+        "the value is the normalized table name. "
+        "The defaults are set to normalize the staging tables created by known ETL tools. "
+        "The tables matched by these rules should typically be temporary or transient tables "
+        "that are not used directly by other tools. DataHub will not be able to detect "
+        "cross-platform lineage for such tables.",
+        # "Only applicable if `use_queries_v2` is enabled.",
+    )
+
     rename_upstreams_deny_pattern_to_temporary_table_pattern = pydantic_renamed_field(
         "upstreams_deny_pattern", "temporary_tables_pattern"
     )
@@ -325,6 +338,15 @@ class SnowflakeV2Config(
         "Only applicable if `use_queries_v2` is enabled.",
     )
 
+    @validator("table_name_normalization_rules", pre=True)
+    def validate_pattern(cls, rules):
+        # The config value is a dict keyed by regex; compile any string keys.
+        # re.compile() passes already-compiled patterns through unchanged.
+        try:
+            return {re.compile(pattern): repl for pattern, repl in rules.items()}
+        except re.error as e:
+            raise ValueError(f"Invalid regular expression: {e}") from e
+
     @validator("convert_urns_to_lowercase")
     def validate_convert_urns_to_lowercase(cls, v):
         if not v:
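To make the new config field concrete, here is a small self-contained sketch of what one rule does after `validate_pattern` has compiled it. The staging-table name and the rule itself are invented for illustration:

```python
import re

# A hypothetical rule, stored the way the validator above produces it: the key
# is a compiled regex, the value is the canonical replacement. In a recipe the
# key would be written as a plain string and compiled by validate_pattern.
rules = {re.compile(r"\b(staging_orders_)\d{8}\b"): r"\1SNAPSHOT"}

query = "INSERT INTO staging_orders_20240615 SELECT * FROM raw.orders"
for pattern, replacement in rules.items():
    query = pattern.sub(replacement, query)

# Every date-stamped copy of the staging table now shares one canonical name.
assert query == "INSERT INTO staging_orders_SNAPSHOT SELECT * FROM raw.orders"
```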
" + "Defaults are to set in such a way to normalize the staging tables created by known ETL tools." + "The tables identified by these rules should typically be temporary or transient tables " + "and should not be used directly in other tools. DataHub will not be able to detect cross" + "-platform lineage for such tables.", + ) + local_temp_path: Optional[pathlib.Path] = pydantic.Field( default=None, description="Local path to store the audit log.", @@ -110,6 +121,15 @@ class SnowflakeQueriesExtractorConfig(ConfigModel): include_query_usage_statistics: bool = True include_operations: bool = True + @pydantic.validator("table_name_normalization_rules") + def validate_pattern(cls, pattern): + if isinstance(pattern, re.Pattern): # Already compiled, good + return pattern + try: + return re.compile(pattern) # Attempt compilation + except re.error as e: + raise ValueError(f"Invalid regular expression: {e}") + class SnowflakeQueriesSourceConfig( SnowflakeQueriesExtractorConfig, SnowflakeIdentifierConfig, SnowflakeFilterConfig @@ -435,7 +455,10 @@ def _parse_audit_log_row( default_db=res["default_db"], default_schema=res["default_schema"], query_hash=get_query_fingerprint( - res["query_text"], self.identifiers.platform, fast=True + res["query_text"], + self.identifiers.platform, + fast=True, + table_name_normalization_rules=self.config.table_name_normalization_rules, ), ) @@ -514,7 +537,10 @@ def _parse_audit_log_row( # job at eliminating redundant / repetitive queries. As such, we include the fast fingerprint # here query_id=get_query_fingerprint( - res["query_text"], self.identifiers.platform, fast=True + res["query_text"], + self.identifiers.platform, + fast=True, + table_name_normalization_rules=self.config.table_name_normalization_rules, ), query_text=res["query_text"], upstreams=upstreams, diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index daeb839e5f54d..9e8be5ade99bf 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -573,6 +573,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: config=SnowflakeQueriesExtractorConfig( window=self.config, temporary_tables_pattern=self.config.temporary_tables_pattern, + table_name_normalization_rules=self.config.table_name_normalization_rules, include_lineage=self.config.include_table_lineage, include_usage_statistics=self.config.include_usage_stats, include_operations=self.config.include_operational_stats, diff --git a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py index 893e89f177094..b89420d7e73c5 100644 --- a/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py +++ b/metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py @@ -337,6 +337,7 @@ def __init__( is_allowed_table: Optional[Callable[[str], bool]] = None, format_queries: bool = True, query_log: QueryLogSetting = _DEFAULT_QUERY_LOG_SETTING, + table_name_normalization_rules: Dict[str, str] = {}, ) -> None: self.platform = DataPlatformUrn(platform) self.platform_instance = platform_instance @@ -495,6 +496,8 @@ def __init__( self._tool_meta_extractor = ToolMetaExtractor.create(graph) self.report.tool_meta_report = self._tool_meta_extractor.report + self.table_name_normalization_rules = table_name_normalization_rules + def close(self) 
         # Compute stats once before closing connections
         self.report.compute_stats()
diff --git a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
index c825deeccd959..7b9820c20a635 100644
--- a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
+++ b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py
@@ -1085,6 +1085,7 @@ def _sqlglot_lineage_inner(
     query_type, query_type_props = get_query_type_of_sql(
         original_statement, dialect=dialect
     )
+    # TODO: support table name normalization rules for non-fast fingerprinting
     query_fingerprint, debug_info.generalized_statement = get_query_fingerprint_debug(
         original_statement, dialect
     )
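The sqlglot_utils.py changes below rename `change_table_names` to `table_name_normalization_rules`. A minimal sketch of the intended fast-path behavior, using the built-in rules; the query is invented, and since the exact output also depends on the basic normalization rules, the commented result is indicative rather than authoritative:

```python
from datahub.sql_parsing.sqlglot_utils import (
    _TABLE_NAME_NORMALIZATION_RULES,
    generalize_query_fast,
)

normalized = generalize_query_fast(
    "SELECT id FROM sales_202401 WHERE region = 'EMEA'",
    dialect="snowflake",
    table_name_normalization_rules=_TABLE_NAME_NORMALIZATION_RULES,
)
# The date suffix collapses to a placeholder and the literal is masked,
# yielding something like: SELECT id FROM sales_YYYYMM WHERE region = ?
print(normalized)
```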
diff --git a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py
index 5b12c64a83166..d81dc31e963db 100644
--- a/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py
+++ b/metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py
@@ -157,24 +157,26 @@ def _expression_to_string(
         r"\b(ge_tmp_|ge_temp_|gx_temp_)[0-9a-f]{8}\b", re.IGNORECASE
     ): r"\1abcdefgh",
     # Date-suffixed table names (e.g. _20210101)
-    re.compile(r"\b(\w+)(19|20)\d{4}\b"): r"\1YYYYMM",
-    re.compile(r"\b(\w+)(19|20)\d{6}\b"): r"\1YYYYMMDD",
-    re.compile(r"\b(\w+)(19|20)\d{8}\b"): r"\1YYYYMMDDHH",
-    re.compile(r"\b(\w+)(19|20)\d{10}\b"): r"\1YYYYMMDDHHMM",
+    re.compile(r"\b(\w+_?)(19|20)\d{4}\b"): r"\1YYYYMM",
+    re.compile(r"\b(\w+_?)(19|20)\d{6}\b"): r"\1YYYYMMDD",
+    re.compile(r"\b(\w+_?)(19|20)\d{8}\b"): r"\1YYYYMMDDHH",
+    re.compile(r"\b(\w+_?)(19|20)\d{10}\b"): r"\1YYYYMMDDHHMM",
+    re.compile(r"\b(\w+_?)(19|20)\d{12}\b"): r"\1YYYYMMDDHHMMSS",
+    re.compile(r"\b(\w+_?)(19|20)\d{18}\b"): r"\1YYYYMMDDHHMMSSffffff",
 }
 
 
 def generalize_query_fast(
     expression: sqlglot.exp.ExpOrStr,
     dialect: DialectOrStr,
-    change_table_names: bool = False,
+    table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
 ) -> str:
     """Variant of `generalize_query` that only does basic normalization.
 
     Args:
         expression: The SQL query to generalize.
         dialect: The SQL dialect to use.
-        change_table_names: If True, replace table names with placeholders. Note
+        table_name_normalization_rules: If set, normalize table names using these rules. Note
             that this should only be used for query filtering purposes, as it
             violates the general assumption that the queries with the same
             fingerprint have the same lineage/usage/etc.
@@ -189,7 +191,7 @@ def generalize_query_fast(
 
     REGEX_REPLACEMENTS = {
         **_BASIC_NORMALIZATION_RULES,
-        **(_TABLE_NAME_NORMALIZATION_RULES if change_table_names else {}),
+        **(table_name_normalization_rules or {}),
     }
 
     for pattern, replacement in REGEX_REPLACEMENTS.items():
@@ -197,7 +199,11 @@ def generalize_query_fast(
     return query_text
 
 
-def generalize_query(expression: sqlglot.exp.ExpOrStr, dialect: DialectOrStr) -> str:
+def generalize_query(
+    expression: sqlglot.exp.ExpOrStr,
+    dialect: DialectOrStr,
+    table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
+) -> str:
     """
     Generalize/normalize a SQL query.
 
@@ -222,6 +228,7 @@ def generalize_query(expression: sqlglot.exp.ExpOrStr, dialect: DialectOrStr) -> str:
     # https://tobikodata.com/are_these_sql_queries_the_same.html
     # which is used to determine if queries are functionally equivalent.
+    # TODO: apply table name normalization rules here
     dialect = get_dialect(dialect)
     expression = sqlglot.maybe_parse(expression, dialect=dialect)
@@ -260,14 +267,23 @@ def generate_hash(text: str) -> str:
 
 
 def get_query_fingerprint_debug(
-    expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr, fast: bool = False
+    expression: sqlglot.exp.ExpOrStr,
+    platform: DialectOrStr,
+    table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
+    fast: bool = False,
 ) -> Tuple[str, Optional[str]]:
     try:
         if not fast:
             dialect = get_dialect(platform)
-            expression_sql = generalize_query(expression, dialect=dialect)
+            expression_sql = generalize_query(
+                expression, dialect, table_name_normalization_rules
+            )
         else:
-            expression_sql = generalize_query_fast(expression, dialect=platform)
+            expression_sql = generalize_query_fast(
+                expression,
+                dialect=platform,
+                table_name_normalization_rules=table_name_normalization_rules,
+            )
     except (ValueError, sqlglot.errors.SqlglotError) as e:
         if not isinstance(expression, str):
             raise
@@ -284,7 +300,10 @@ def get_query_fingerprint_debug(
 
 
 def get_query_fingerprint(
-    expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr, fast: bool = False
+    expression: sqlglot.exp.ExpOrStr,
+    platform: DialectOrStr,
+    fast: bool = False,
+    table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
 ) -> str:
     """Get a fingerprint for a SQL query.
 
@@ -306,7 +325,12 @@ def get_query_fingerprint(
         The fingerprint for the SQL query.
     """
 
-    return get_query_fingerprint_debug(expression, platform, fast=fast)[0]
+    return get_query_fingerprint_debug(
+        expression,
+        platform,
+        fast=fast,
+        table_name_normalization_rules=table_name_normalization_rules,
+    )[0]
 
 
 @functools.lru_cache(maxsize=FORMAT_QUERY_CACHE_SIZE)
diff --git a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_utils.py b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_utils.py
index c3c3a4a15d915..3070bd095c8f2 100644
--- a/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_utils.py
+++ b/metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_utils.py
@@ -8,6 +8,7 @@
 from datahub.sql_parsing.sql_parsing_common import QueryType
 from datahub.sql_parsing.sqlglot_lineage import _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT
 from datahub.sql_parsing.sqlglot_utils import (
+    _TABLE_NAME_NORMALIZATION_RULES,
     generalize_query,
     generalize_query_fast,
     get_dialect,
@@ -173,7 +174,11 @@ def test_query_generalization(
     if mode in {QueryGeneralizationTestMode.FULL, QueryGeneralizationTestMode.BOTH}:
         assert generalize_query(query, dialect=dialect) == expected
     if mode in {QueryGeneralizationTestMode.FAST, QueryGeneralizationTestMode.BOTH}:
         assert (
-            generalize_query_fast(query, dialect=dialect, change_table_names=True)
+            generalize_query_fast(
+                query,
+                dialect=dialect,
+                table_name_normalization_rules=_TABLE_NAME_NORMALIZATION_RULES,
+            )
             == expected
         )
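Taken together, the plumbing above lets fast fingerprints treat date-stamped staging tables as the same logical table. A minimal sketch under the built-in rules, with invented table names:

```python
from datahub.sql_parsing.sqlglot_utils import (
    _TABLE_NAME_NORMALIZATION_RULES,
    get_query_fingerprint,
)


def fast_fingerprint(sql: str) -> str:
    # Fast fingerprinting with the built-in date-suffix normalization rules.
    return get_query_fingerprint(
        sql,
        platform="snowflake",
        fast=True,
        table_name_normalization_rules=_TABLE_NAME_NORMALIZATION_RULES,
    )


# Both queries normalize to "... tmp_orders_YYYYMMDD ...", so they collide on
# purpose: they are the same logical query from two different daily runs.
assert fast_fingerprint(
    "SELECT * FROM tmp_orders_20210101"
) == fast_fingerprint("SELECT * FROM tmp_orders_20240615")
```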