feat(ingest/snowflake): apply table name normalization for queries_v2
Adds config `table_name_normalization_rules` to the Snowflake source.
The tables identified by these rules should typically be temporary or transient.
This enables uniform query_id generation for queries that use randomized table
names on each ETL run, as produced by tools like dbt, Segment, and Fivetran.
mayurinehate committed Feb 6, 2025
1 parent ac13f25 commit 40444fe
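To illustrate the effect (an editor's sketch, not part of the commit: the rule regex and queries are hypothetical, and the call assumes the updated `get_query_fingerprint` signature from this diff), a normalization rule makes two runs of the same ETL query hash to the same query_id despite their randomized staging-table names:

import re

from datahub.sql_parsing.sqlglot_utils import get_query_fingerprint

# Hypothetical rule: collapse Fivetran-style staging tables that get a random
# hex suffix on every run.
rules = {
    re.compile(r"\bfivetran_staging_[0-9a-f]{8}\b", re.IGNORECASE): "fivetran_staging_normalized",
}

run_1 = "INSERT INTO sales SELECT * FROM fivetran_staging_1a2b3c4d"
run_2 = "INSERT INTO sales SELECT * FROM fivetran_staging_9f8e7d6c"

# Without rules, the random suffix produces a new fingerprint on every run.
assert get_query_fingerprint(run_1, "snowflake", fast=True) != get_query_fingerprint(
    run_2, "snowflake", fast=True
)

# With rules, both runs collapse to a single query_id (rules apply on the fast path).
assert get_query_fingerprint(
    run_1, "snowflake", fast=True, table_name_normalization_rules=rules
) == get_query_fingerprint(
    run_2, "snowflake", fast=True, table_name_normalization_rules=rules
)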
Showing 8 changed files with 101 additions and 17 deletions.
metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
@@ -1,4 +1,5 @@
 import logging
+import re
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Set
@@ -300,6 +301,18 @@ class SnowflakeV2Config(
"to ignore the temporary staging tables created by known ETL tools.",
)

+    table_name_normalization_rules: Dict[re.Pattern, str] = pydantic.Field(
+        default={},
+        description="[Advanced] Regex patterns for table names to normalize in lineage ingestion. "
+        "Specify the key as a regex matching the table name as it appears in the query. "
+        "The value is the normalized table name. "
+        "Defaults are set in such a way as to normalize the staging tables created by known ETL tools. "
+        "The tables identified by these rules should typically be temporary or transient tables, "
+        "and should not be used directly in other tools. DataHub will not be able to detect "
+        "cross-platform lineage for such tables.",
+        # "Only applicable if `use_queries_v2` is enabled.",
+    )

     rename_upstreams_deny_pattern_to_temporary_table_pattern = pydantic_renamed_field(
         "upstreams_deny_pattern", "temporary_tables_pattern"
     )
@@ -325,6 +338,15 @@ class SnowflakeV2Config(
"Only applicable if `use_queries_v2` is enabled.",
)

@validator("table_name_normalization_rules")
def validate_pattern(cls, pattern):
if isinstance(pattern, re.Pattern): # Already compiled, good
return pattern

Check warning on line 344 in metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py

View check run for this annotation

Codecov / codecov/patch

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py#L344

Added line #L344 was not covered by tests
try:
return re.compile(pattern) # Attempt compilation
except re.error as e:
raise ValueError(f"Invalid regular expression: {e}")


@validator("convert_urns_to_lowercase")
def validate_convert_urns_to_lowercase(cls, v):
if not v:
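As a standalone illustration of this validation behavior (an editor's sketch mirroring the validator above, outside of pydantic), string keys compile to patterns and a malformed regex fails fast:

import re

def compile_rules(rules):
    # Mirrors the validator: keep compiled patterns, compile string keys,
    # and surface regex syntax errors as a ValueError.
    validated = {}
    for pattern, replacement in rules.items():
        if not isinstance(pattern, re.Pattern):
            try:
                pattern = re.compile(pattern)
            except re.error as e:
                raise ValueError(f"Invalid regular expression {pattern!r}: {e}")
        validated[pattern] = replacement
    return validated

ok = compile_rules({r"\bsegment_tmp_\d+\b": "segment_tmp_normalized"})
assert all(isinstance(k, re.Pattern) for k in ok)

try:
    compile_rules({"([unbalanced": "x"})
except ValueError as e:
    print(e)  # Invalid regular expression '([unbalanced': ...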
@@ -241,7 +241,9 @@ def get_known_query_lineage(

         known_lineage = KnownQueryLineageInfo(
             query_id=get_query_fingerprint(
-                query.query_text, self.identifiers.platform, fast=True
+                query.query_text,
+                self.identifiers.platform,
+                fast=True,
             ),
             query_text=query.query_text,
             downstream=downstream_table_urn,
metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py
@@ -96,6 +96,17 @@ class SnowflakeQueriesExtractorConfig(ConfigModel):
"to ignore the temporary staging tables created by known ETL tools.",
)

+    table_name_normalization_rules: Dict[re.Pattern, str] = pydantic.Field(
+        default={},
+        description="[Advanced] Regex patterns for table names to normalize in lineage ingestion. "
+        "Specify the key as a regex matching the table name as it appears in the query. "
+        "The value is the normalized table name. "
+        "Defaults are set in such a way as to normalize the staging tables created by known ETL tools. "
+        "The tables identified by these rules should typically be temporary or transient tables, "
+        "and should not be used directly in other tools. DataHub will not be able to detect "
+        "cross-platform lineage for such tables.",
+    )

     local_temp_path: Optional[pathlib.Path] = pydantic.Field(
         default=None,
         description="Local path to store the audit log.",
@@ -110,6 +121,15 @@ class SnowflakeQueriesExtractorConfig(ConfigModel):
     include_query_usage_statistics: bool = True
     include_operations: bool = True

@pydantic.validator("table_name_normalization_rules")
def validate_pattern(cls, pattern):
if isinstance(pattern, re.Pattern): # Already compiled, good
return pattern
try:
return re.compile(pattern) # Attempt compilation
except re.error as e:
raise ValueError(f"Invalid regular expression: {e}")



 class SnowflakeQueriesSourceConfig(
     SnowflakeQueriesExtractorConfig, SnowflakeIdentifierConfig, SnowflakeFilterConfig
@@ -435,7 +455,10 @@ def _parse_audit_log_row(
             default_db=res["default_db"],
             default_schema=res["default_schema"],
             query_hash=get_query_fingerprint(
-                res["query_text"], self.identifiers.platform, fast=True
+                res["query_text"],
+                self.identifiers.platform,
+                fast=True,
+                table_name_normalization_rules=self.config.table_name_normalization_rules,
             ),
         )

@@ -514,7 +537,10 @@ def _parse_audit_log_row(
             # job at eliminating redundant / repetitive queries. As such, we include the fast fingerprint
             # here
             query_id=get_query_fingerprint(
-                res["query_text"], self.identifiers.platform, fast=True
+                res["query_text"],
+                self.identifiers.platform,
+                fast=True,
+                table_name_normalization_rules=self.config.table_name_normalization_rules,
            ),
             query_text=res["query_text"],
             upstreams=upstreams,
@@ -573,6 +573,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
             config=SnowflakeQueriesExtractorConfig(
                 window=self.config,
                 temporary_tables_pattern=self.config.temporary_tables_pattern,
+                table_name_normalization_rules=self.config.table_name_normalization_rules,
                 include_lineage=self.config.include_table_lineage,
                 include_usage_statistics=self.config.include_usage_stats,
                 include_operations=self.config.include_operational_stats,
@@ -337,6 +337,7 @@ def __init__(
         is_allowed_table: Optional[Callable[[str], bool]] = None,
         format_queries: bool = True,
         query_log: QueryLogSetting = _DEFAULT_QUERY_LOG_SETTING,
+        table_name_normalization_rules: Optional[Dict[str, str]] = None,
     ) -> None:
         self.platform = DataPlatformUrn(platform)
         self.platform_instance = platform_instance
@@ -495,6 +496,8 @@ def __init__(
         self._tool_meta_extractor = ToolMetaExtractor.create(graph)
         self.report.tool_meta_report = self._tool_meta_extractor.report
 
+        self.table_name_normalization_rules = table_name_normalization_rules or {}

     def close(self) -> None:
         # Compute stats once before closing connections
         self.report.compute_stats()
@@ -1085,6 +1085,7 @@ def _sqlglot_lineage_inner(
     query_type, query_type_props = get_query_type_of_sql(
         original_statement, dialect=dialect
     )
+    # TODO: support table name normalization rules for non-fast fingerprinting
     query_fingerprint, debug_info.generalized_statement = get_query_fingerprint_debug(
         original_statement, dialect
     )
50 changes: 37 additions & 13 deletions metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py
@@ -157,24 +157,26 @@ def _expression_to_string(
r"\b(ge_tmp_|ge_temp_|gx_temp_)[0-9a-f]{8}\b", re.IGNORECASE
): r"\1abcdefgh",
# Date-suffixed table names (e.g. _20210101)
re.compile(r"\b(\w+)(19|20)\d{4}\b"): r"\1YYYYMM",
re.compile(r"\b(\w+)(19|20)\d{6}\b"): r"\1YYYYMMDD",
re.compile(r"\b(\w+)(19|20)\d{8}\b"): r"\1YYYYMMDDHH",
re.compile(r"\b(\w+)(19|20)\d{10}\b"): r"\1YYYYMMDDHHMM",
re.compile(r"\b(\w+_?)(19|20)\d{4}\b"): r"\1YYYYMM",
re.compile(r"\b(\w+_?)(19|20)\d{6}\b"): r"\1YYYYMMDD",
re.compile(r"\b(\w+_?)(19|20)\d{8}\b"): r"\1YYYYMMDDHH",
re.compile(r"\b(\w+_?)(19|20)\d{10}\b"): r"\1YYYYMMDDHHMM",
re.compile(r"\b(\w+_?)(19|20)\d{12}\b"): r"\1YYYYMMDDHHMMSS",
re.compile(r"\b(\w+_?)(19|20)\d{18}\b"): r"\1YYYYMMDDHHMMSSffffff",
}
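As a quick illustration of these date-suffix rules (an editor's sketch; the query is made up), a century-prefixed 8-digit date collapses to a YYYYMMDD placeholder whether or not an underscore separates it from the base name:

import re

# One of the updated built-in rules: (\w+_?) keeps the base name (and optional
# trailing underscore) while the date suffix collapses to YYYYMMDD.
rule = re.compile(r"\b(\w+_?)(19|20)\d{6}\b")

print(rule.sub(r"\1YYYYMMDD", "SELECT * FROM events_20240101"))  # events_YYYYMMDD
print(rule.sub(r"\1YYYYMMDD", "SELECT * FROM events20240101"))   # eventsYYYYMMDD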


 def generalize_query_fast(
     expression: sqlglot.exp.ExpOrStr,
     dialect: DialectOrStr,
-    change_table_names: bool = False,
+    table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
 ) -> str:
     """Variant of `generalize_query` that only does basic normalization.
 
     Args:
         expression: The SQL query to generalize.
         dialect: The SQL dialect to use.
-        change_table_names: If True, replace table names with placeholders. Note
+        table_name_normalization_rules: If set, replace table names with placeholders. Note
             that this should only be used for query filtering purposes, as it
             violates the general assumption that queries with the same fingerprint
             have the same lineage/usage/etc.
@@ -189,15 +191,19 @@ def generalize_query_fast(

     REGEX_REPLACEMENTS = {
         **_BASIC_NORMALIZATION_RULES,
-        **(_TABLE_NAME_NORMALIZATION_RULES if change_table_names else {}),
+        **(table_name_normalization_rules if table_name_normalization_rules else {}),
     }
 
     for pattern, replacement in REGEX_REPLACEMENTS.items():
         query_text = pattern.sub(replacement, query_text)
     return query_text
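A usage sketch for the reworked fast path (editor's illustration; the query is made up, while `_TABLE_NAME_NORMALIZATION_RULES` is the built-in dict defined above): callers now pass a rules dict instead of the old boolean flag.

from datahub.sql_parsing.sqlglot_utils import (
    _TABLE_NAME_NORMALIZATION_RULES,
    generalize_query_fast,
)

# Basic literal normalization always applies; table-name rules only when given.
print(
    generalize_query_fast(
        "SELECT * FROM ge_tmp_1a2b3c4d",
        dialect="snowflake",
        table_name_normalization_rules=_TABLE_NAME_NORMALIZATION_RULES,
    )
)
# The ge_tmp_ rule rewrites the random hex suffix to a fixed one, so the
# table should render as ge_tmp_abcdefgh.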


-def generalize_query(expression: sqlglot.exp.ExpOrStr, dialect: DialectOrStr) -> str:
+def generalize_query(
+    expression: sqlglot.exp.ExpOrStr,
+    dialect: DialectOrStr,
+    table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
+) -> str:
"""
Generalize/normalize a SQL query.
@@ -222,6 +228,7 @@ def generalize_query(expression: sqlglot.exp.ExpOrStr, dialect: DialectOrStr) ->
     # https://tobikodata.com/are_these_sql_queries_the_same.html
     # which is used to determine if queries are functionally equivalent.
 
+    # TODO: apply table name normalization rules here
     dialect = get_dialect(dialect)
     expression = sqlglot.maybe_parse(expression, dialect=dialect)

@@ -260,14 +267,23 @@ def generate_hash(text: str) -> str:


 def get_query_fingerprint_debug(
-    expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr, fast: bool = False
+    expression: sqlglot.exp.ExpOrStr,
+    platform: DialectOrStr,
+    table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
+    fast: bool = False,
 ) -> Tuple[str, Optional[str]]:
     try:
         if not fast:
             dialect = get_dialect(platform)
-            expression_sql = generalize_query(expression, dialect=dialect)
+            expression_sql = generalize_query(
+                expression, dialect, table_name_normalization_rules
+            )
         else:
-            expression_sql = generalize_query_fast(expression, dialect=platform)
+            expression_sql = generalize_query_fast(
+                expression,
+                dialect=platform,
+                table_name_normalization_rules=table_name_normalization_rules,
+            )
     except (ValueError, sqlglot.errors.SqlglotError) as e:
         if not isinstance(expression, str):
             raise
@@ -284,7 +300,10 @@ def get_query_fingerprint_debug(


 def get_query_fingerprint(
-    expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr, fast: bool = False
+    expression: sqlglot.exp.ExpOrStr,
+    platform: DialectOrStr,
+    fast: bool = False,
+    table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
 ) -> str:
     """Get a fingerprint for a SQL query.
@@ -306,7 +325,12 @@ def get_query_fingerprint(
         The fingerprint for the SQL query.
     """

-    return get_query_fingerprint_debug(expression, platform, fast=fast)[0]
+    return get_query_fingerprint_debug(
+        expression,
+        platform,
+        fast=fast,
+        table_name_normalization_rules=(table_name_normalization_rules or {}),
+    )[0]
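For completeness, a sketch of the debug variant (illustrative query; return shape per the `Tuple[str, Optional[str]]` signature above), which exposes the generalized statement that was hashed:

from datahub.sql_parsing.sqlglot_utils import get_query_fingerprint_debug

fingerprint, generalized = get_query_fingerprint_debug(
    "SELECT * FROM orders WHERE id = 42",
    "snowflake",
    fast=True,
)
print(fingerprint)  # deterministic hash of the generalized statement
print(generalized)  # literals normalized by the basic rules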


 @functools.lru_cache(maxsize=FORMAT_QUERY_CACHE_SIZE)
@@ -8,6 +8,7 @@
 from datahub.sql_parsing.sql_parsing_common import QueryType
 from datahub.sql_parsing.sqlglot_lineage import _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT
 from datahub.sql_parsing.sqlglot_utils import (
+    _TABLE_NAME_NORMALIZATION_RULES,
     generalize_query,
     generalize_query_fast,
     get_dialect,
@@ -173,7 +174,11 @@ def test_query_generalization(
     assert generalize_query(query, dialect=dialect) == expected
     if mode in {QueryGeneralizationTestMode.FAST, QueryGeneralizationTestMode.BOTH}:
         assert (
-            generalize_query_fast(query, dialect=dialect, change_table_names=True)
+            generalize_query_fast(
+                query,
+                dialect=dialect,
+                table_name_normalization_rules=_TABLE_NAME_NORMALIZATION_RULES,
+            )
             == expected
         )
