feat(ingest/snowflake): apply table name normalization for queries_v2 #12566

Draft · wants to merge 1 commit into master
metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py
@@ -1,4 +1,5 @@
import logging
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Set
@@ -300,6 +301,18 @@
"to ignore the temporary staging tables created by known ETL tools.",
)

table_name_normalization_rules: Dict[re.Pattern, str] = pydantic.Field(
Collaborator: where are the defaults injected in?

default={},
description="[Advanced] Regex patterns for table names to normalize in lineage ingestion. "
Collaborator: does it make sense for us to have table_name_normalization_rules and extra_table_name_normalization_rules? Setting the former would override the defaults, while the latter would let you extend them.

the ruff linter uses a similar pattern which works pretty well https://docs.astral.sh/ruff/settings/#lint_extend-select

Collaborator (Author): That's an excellent idea. I have not injected the defaults yet; that is exactly the problem I was looking to solve: how to let users extend the defaults without re-declaring them, while still letting them replace defaults that do not suit them. (A sketch of this pattern follows the hunk below.)

"Specify key as regex to match the table name as it appears in query. "
"The value is the normalized table name. "
"Defaults are to set in such a way to normalize the staging tables created by known ETL tools."
"The tables identified by these rules should typically be temporary or transient tables "
"and should not be used directly in other tools. DataHub will not be able to detect cross"
"-platform lineage for such tables.",
# "Only applicable if `use_queries_v2` is enabled.",
)

rename_upstreams_deny_pattern_to_temporary_table_pattern = pydantic_renamed_field(
"upstreams_deny_pattern", "temporary_tables_pattern"
)
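
Picking up the override-vs-extend suggestion from the thread above, here is a minimal sketch of what the ruff-style select / extend-select split could look like. It is illustrative only: NormalizationRulesConfig, extra_table_name_normalization_rules, and effective_rules are hypothetical names, not part of this PR; the single default rule is copied from _TABLE_NAME_NORMALIZATION_RULES below.

```python
import re
from typing import Dict

import pydantic


class NormalizationRulesConfig(pydantic.BaseModel):
    # Setting this field replaces the built-in defaults entirely
    # (analogous to ruff's `lint.select`).
    table_name_normalization_rules: Dict[str, str] = {
        r"\b(ge_tmp_|ge_temp_|gx_temp_)[0-9a-f]{8}\b": r"\1abcdefgh",
    }
    # Setting this field extends whatever the field above resolves to
    # (analogous to ruff's `lint.extend-select`).
    extra_table_name_normalization_rules: Dict[str, str] = {}

    def effective_rules(self) -> Dict[re.Pattern, str]:
        # Extras win on key collisions; compile at the end so both
        # fields stay plain strings in the recipe.
        merged = {
            **self.table_name_normalization_rules,
            **self.extra_table_name_normalization_rules,
        }
        return {re.compile(k): v for k, v in merged.items()}
```

With this shape, a user who wants one extra rule sets only extra_table_name_normalization_rules and keeps the shipped defaults, while a user whose defaults misbehave overrides table_name_normalization_rules outright.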
@@ -325,6 +338,15 @@
"Only applicable if `use_queries_v2` is enabled.",
)

@validator("table_name_normalization_rules", pre=True)
def validate_pattern(cls, rules):
    # Compile string keys eagerly so that an invalid regex fails config
    # validation instead of blowing up mid-ingestion.
    compiled = {}
    for pattern, replacement in rules.items():
        if isinstance(pattern, re.Pattern):  # Already compiled, good
            compiled[pattern] = replacement
            continue
        try:
            compiled[re.compile(pattern)] = replacement  # Attempt compilation
        except re.error as e:
            raise ValueError(f"Invalid regular expression {pattern!r}: {e}")
    return compiled

Codecov / codecov/patch warning on metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py: added lines #L344 and #L348 were not covered by tests.

@validator("convert_urns_to_lowercase")
def validate_convert_urns_to_lowercase(cls, v):
if not v:
@@ -241,7 +241,9 @@ def get_known_query_lineage(

known_lineage = KnownQueryLineageInfo(
query_id=get_query_fingerprint(
query.query_text, self.identifiers.platform, fast=True
query.query_text,
self.identifiers.platform,
fast=True,
),
query_text=query.query_text,
downstream=downstream_table_urn,
metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py

@@ -96,6 +96,17 @@
"to ignore the temporary staging tables created by known ETL tools.",
)

table_name_normalization_rules: Dict[re.Pattern, str] = pydantic.Field(
default={},
description="[Advanced] Regex patterns for table names to normalize in lineage ingestion. "
"Specify key as regex to match the table name as it appears in query. "
"The value is the normalized table name. "
"Defaults are to set in such a way to normalize the staging tables created by known ETL tools."
"The tables identified by these rules should typically be temporary or transient tables "
"and should not be used directly in other tools. DataHub will not be able to detect cross"
"-platform lineage for such tables.",
)

local_temp_path: Optional[pathlib.Path] = pydantic.Field(
default=None,
description="Local path to store the audit log.",
@@ -110,6 +121,15 @@
include_query_usage_statistics: bool = True
include_operations: bool = True

@pydantic.validator("table_name_normalization_rules", pre=True)
def validate_pattern(cls, rules):
    # Compile string keys eagerly so that an invalid regex fails config
    # validation instead of blowing up mid-ingestion.
    compiled = {}
    for pattern, replacement in rules.items():
        if isinstance(pattern, re.Pattern):  # Already compiled, good
            compiled[pattern] = replacement
            continue
        try:
            compiled[re.compile(pattern)] = replacement  # Attempt compilation
        except re.error as e:
            raise ValueError(f"Invalid regular expression {pattern!r}: {e}")
    return compiled

Codecov / codecov/patch warning on metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py: added lines #L126-L131 were not covered by tests.


class SnowflakeQueriesSourceConfig(
SnowflakeQueriesExtractorConfig, SnowflakeIdentifierConfig, SnowflakeFilterConfig
@@ -435,7 +455,10 @@
default_db=res["default_db"],
default_schema=res["default_schema"],
query_hash=get_query_fingerprint(
res["query_text"], self.identifiers.platform, fast=True
res["query_text"],
self.identifiers.platform,
fast=True,
table_name_normalization_rules=self.config.table_name_normalization_rules,
),
)

@@ -514,7 +537,10 @@
# job at eliminating redundant / repetitive queries. As such, we include the fast fingerprint
# here
query_id=get_query_fingerprint(
res["query_text"], self.identifiers.platform, fast=True
res["query_text"],
self.identifiers.platform,
fast=True,
table_name_normalization_rules=self.config.table_name_normalization_rules,
),
query_text=res["query_text"],
upstreams=upstreams,
@@ -573,6 +573,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
config=SnowflakeQueriesExtractorConfig(
window=self.config,
temporary_tables_pattern=self.config.temporary_tables_pattern,
table_name_normalization_rules=self.config.table_name_normalization_rules,
include_lineage=self.config.include_table_lineage,
include_usage_statistics=self.config.include_usage_stats,
include_operations=self.config.include_operational_stats,
@@ -337,6 +337,7 @@ def __init__(
is_allowed_table: Optional[Callable[[str], bool]] = None,
format_queries: bool = True,
query_log: QueryLogSetting = _DEFAULT_QUERY_LOG_SETTING,
table_name_normalization_rules: Optional[Dict[str, str]] = None,
) -> None:
self.platform = DataPlatformUrn(platform)
self.platform_instance = platform_instance
@@ -495,6 +496,8 @@ def __init__(
self._tool_meta_extractor = ToolMetaExtractor.create(graph)
self.report.tool_meta_report = self._tool_meta_extractor.report

self.table_name_normalization_rules = table_name_normalization_rules or {}

def close(self) -> None:
# Compute stats once before closing connections
self.report.compute_stats()
@@ -1085,6 +1085,7 @@ def _sqlglot_lineage_inner(
query_type, query_type_props = get_query_type_of_sql(
original_statement, dialect=dialect
)
# TODO: support table name normalization rules for non-fast fingerprinting
query_fingerprint, debug_info.generalized_statement = get_query_fingerprint_debug(
original_statement, dialect
)
50 changes: 37 additions & 13 deletions metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py

@@ -157,24 +157,26 @@ def _expression_to_string(
r"\b(ge_tmp_|ge_temp_|gx_temp_)[0-9a-f]{8}\b", re.IGNORECASE
): r"\1abcdefgh",
# Date-suffixed table names (e.g. _20210101)
re.compile(r"\b(\w+)(19|20)\d{4}\b"): r"\1YYYYMM",
re.compile(r"\b(\w+)(19|20)\d{6}\b"): r"\1YYYYMMDD",
re.compile(r"\b(\w+)(19|20)\d{8}\b"): r"\1YYYYMMDDHH",
re.compile(r"\b(\w+)(19|20)\d{10}\b"): r"\1YYYYMMDDHHMM",
re.compile(r"\b(\w+_?)(19|20)\d{4}\b"): r"\1YYYYMM",
re.compile(r"\b(\w+_?)(19|20)\d{6}\b"): r"\1YYYYMMDD",
re.compile(r"\b(\w+_?)(19|20)\d{8}\b"): r"\1YYYYMMDDHH",
re.compile(r"\b(\w+_?)(19|20)\d{10}\b"): r"\1YYYYMMDDHHMM",
re.compile(r"\b(\w+_?)(19|20)\d{12}\b"): r"\1YYYYMMDDHHMMSS",
re.compile(r"\b(\w+_?)(19|20)\d{18}\b"): r"\1YYYYMMDDHHMMSSffffff",
}
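
To make the date-suffix rules concrete, here is a small self-contained sketch of how they fire; the substitution loop mirrors the one in generalize_query_fast below, and the table names are made up:

```python
import re

# Two of the date-suffix rules from _TABLE_NAME_NORMALIZATION_RULES above.
rules = {
    re.compile(r"\b(\w+_?)(19|20)\d{6}\b"): r"\1YYYYMMDD",
    re.compile(r"\b(\w+_?)(19|20)\d{12}\b"): r"\1YYYYMMDDHHMMSS",
}

query = "INSERT INTO events_20210101 SELECT * FROM raw_events_20210101120000"
for pattern, replacement in rules.items():
    query = pattern.sub(replacement, query)

print(query)
# INSERT INTO events_YYYYMMDD SELECT * FROM raw_events_YYYYMMDDHHMMSS
```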


def generalize_query_fast(
expression: sqlglot.exp.ExpOrStr,
dialect: DialectOrStr,
change_table_names: bool = False,
table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
) -> str:
"""Variant of `generalize_query` that only does basic normalization.

Args:
expression: The SQL query to generalize.
dialect: The SQL dialect to use.
change_table_names: If True, replace table names with placeholders. Note
table_name_normalization_rules: If set, replace matching table names using these regex rules. Note
that this should only be used for query filtering purposes, as it
violates the general assumption that the queries with the same fingerprint
have the same lineage/usage/etc.
@@ -189,15 +191,19 @@ def generalize_query_fast(

REGEX_REPLACEMENTS = {
**_BASIC_NORMALIZATION_RULES,
**(_TABLE_NAME_NORMALIZATION_RULES if change_table_names else {}),
**(table_name_normalization_rules if table_name_normalization_rules else {}),
}

for pattern, replacement in REGEX_REPLACEMENTS.items():
query_text = pattern.sub(replacement, query_text)
return query_text
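
For reference, a hypothetical call through the new parameter (the stg_tmp_ pattern is illustrative, not a shipped default):

```python
import re

from datahub.sql_parsing.sqlglot_utils import generalize_query_fast

rules = {re.compile(r"\bstg_tmp_\d+\b"): "stg_tmp_N"}
normalized = generalize_query_fast(
    "SELECT id, total FROM stg_tmp_42",
    dialect="snowflake",
    table_name_normalization_rules=rules,
)
# The staging-table reference collapses to stg_tmp_N; the basic
# normalization rules still apply on top.
print(normalized)
```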


def generalize_query(expression: sqlglot.exp.ExpOrStr, dialect: DialectOrStr) -> str:
def generalize_query(
expression: sqlglot.exp.ExpOrStr,
dialect: DialectOrStr,
table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
) -> str:
"""
Generalize/normalize a SQL query.

@@ -222,6 +228,7 @@ def generalize_query(expression: sqlglot.exp.ExpOrStr, dialect: DialectOrStr) ->
# https://tobikodata.com/are_these_sql_queries_the_same.html
# which is used to determine if queries are functionally equivalent.

# TODO: apply table name normalization rules here
dialect = get_dialect(dialect)
expression = sqlglot.maybe_parse(expression, dialect=dialect)

@@ -260,14 +267,23 @@ def generate_hash(text: str) -> str:


def get_query_fingerprint_debug(
expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr, fast: bool = False
expression: sqlglot.exp.ExpOrStr,
platform: DialectOrStr,
table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
fast: bool = False,
) -> Tuple[str, Optional[str]]:
try:
if not fast:
dialect = get_dialect(platform)
expression_sql = generalize_query(expression, dialect=dialect)
expression_sql = generalize_query(
expression, dialect, table_name_normalization_rules
)
else:
expression_sql = generalize_query_fast(expression, dialect=platform)
expression_sql = generalize_query_fast(
expression,
dialect=platform,
table_name_normalization_rules=table_name_normalization_rules,
)
except (ValueError, sqlglot.errors.SqlglotError) as e:
if not isinstance(expression, str):
raise
@@ -284,7 +300,10 @@ def get_query_fingerprint_debug(


def get_query_fingerprint(
expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr, fast: bool = False
expression: sqlglot.exp.ExpOrStr,
platform: DialectOrStr,
fast: bool = False,
table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
) -> str:
"""Get a fingerprint for a SQL query.

@@ -306,7 +325,12 @@
The fingerprint for the SQL query.
"""

return get_query_fingerprint_debug(expression, platform, fast=fast)[0]
return get_query_fingerprint_debug(
expression,
platform,
fast=fast,
table_name_normalization_rules=(table_name_normalization_rules or {}),
)[0]
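
As a sanity check on the intended behavior, a sketch (with a made-up rule) showing that two runs of the same job differing only in a date-suffixed staging table should collapse to one fingerprint:

```python
import re

from datahub.sql_parsing.sqlglot_utils import get_query_fingerprint

rules = {re.compile(r"\bstg_orders_(19|20)\d{6}\b"): "stg_orders_YYYYMMDD"}

fp_day1 = get_query_fingerprint(
    "INSERT INTO orders SELECT * FROM stg_orders_20240101",
    platform="snowflake",
    fast=True,
    table_name_normalization_rules=rules,
)
fp_day2 = get_query_fingerprint(
    "INSERT INTO orders SELECT * FROM stg_orders_20240102",
    platform="snowflake",
    fast=True,
    table_name_normalization_rules=rules,
)
# Both days normalize to "... FROM stg_orders_YYYYMMDD", so the hashes match.
assert fp_day1 == fp_day2
```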


@functools.lru_cache(maxsize=FORMAT_QUERY_CACHE_SIZE)
@@ -8,6 +8,7 @@
from datahub.sql_parsing.sql_parsing_common import QueryType
from datahub.sql_parsing.sqlglot_lineage import _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT
from datahub.sql_parsing.sqlglot_utils import (
_TABLE_NAME_NORMALIZATION_RULES,
Collaborator: we should remove the _ prefix if this is a public thing that other sources depend on.

generalize_query,
generalize_query_fast,
get_dialect,
@@ -173,7 +174,11 @@ def test_query_generalization(
assert generalize_query(query, dialect=dialect) == expected
if mode in {QueryGeneralizationTestMode.FAST, QueryGeneralizationTestMode.BOTH}:
assert (
generalize_query_fast(query, dialect=dialect, change_table_names=True)
generalize_query_fast(
query,
dialect=dialect,
table_name_normalization_rules=_TABLE_NAME_NORMALIZATION_RULES,
)
== expected
)
