
Commit 40444fe
feat(ingest/snowflake): apply table name normalization for queries_v2
Adds a `table_name_normalization_rules` config to the Snowflake source. The tables identified by these rules should typically be temporary or transient. This enables uniform query_id generation for queries that use a random table name on each ETL run, as tools like DBT, Segment, and Fivetran do.
1 parent ac13f25
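To make the feature concrete, here is a sketch of the kind of rule set this config is meant to carry. The patterns and table names are invented for illustration, not shipped defaults: each key matches the per-run name of a staging table, and the value rewrites it to a stable placeholder, so the fast query fingerprint (and hence the query_id) no longer churns between runs.

import re

# Hypothetical normalization rules; a real recipe would target the exact
# naming scheme of its ETL tool.
table_name_normalization_rules = {
    # dbt-style temp relations: orders__dbt_tmp_1a2b3c -> orders__dbt_tmp
    re.compile(r"\b(\w+__dbt_tmp)_[0-9a-z]+\b", re.IGNORECASE): r"\1",
    # Fivetran-style staging tables: fivetran_stage_20250101_0042 -> fivetran_stage
    re.compile(r"\b(fivetran_stage)_\w+\b", re.IGNORECASE): r"\1",
}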

File tree: 8 files changed, +101 / -17 lines

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py

Lines changed: 22 additions & 0 deletions
@@ -1,4 +1,5 @@
 import logging
+import re
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Set
@@ -300,6 +301,18 @@ class SnowflakeV2Config(
         "to ignore the temporary staging tables created by known ETL tools.",
     )

+    table_name_normalization_rules: Dict[re.Pattern, str] = pydantic.Field(
+        default={},
+        description="[Advanced] Regex patterns for table names to normalize in lineage ingestion. "
+        "Specify the key as a regex matching the table name as it appears in the query. "
+        "The value is the normalized table name. "
+        "Defaults are set to normalize the staging tables created by known ETL tools. "
+        "The tables identified by these rules should typically be temporary or transient tables "
+        "and should not be used directly in other tools. DataHub will not be able to detect "
+        "cross-platform lineage for such tables.",
+        # "Only applicable if `use_queries_v2` is enabled.",
+    )
+
     rename_upstreams_deny_pattern_to_temporary_table_pattern = pydantic_renamed_field(
         "upstreams_deny_pattern", "temporary_tables_pattern"
     )
@@ -325,6 +338,15 @@ class SnowflakeV2Config(
         "Only applicable if `use_queries_v2` is enabled.",
     )

+    @validator("table_name_normalization_rules")
+    def validate_pattern(cls, pattern):
+        if isinstance(pattern, re.Pattern):  # Already compiled, good
+            return pattern
+        try:
+            return re.compile(pattern)  # Attempt compilation
+        except re.error as e:
+            raise ValueError(f"Invalid regular expression: {e}")
+
     @validator("convert_urns_to_lowercase")
     def validate_convert_urns_to_lowercase(cls, v):
         if not v:
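The validator added above either passes through an already-compiled pattern or compiles a string, failing fast on malformed regexes. A standalone sketch of that logic, where `compile_rule` is a hypothetical helper rather than part of this commit:

import re
from typing import Union


def compile_rule(pattern: Union[str, re.Pattern]) -> re.Pattern:
    if isinstance(pattern, re.Pattern):  # already compiled, pass through
        return pattern
    try:
        return re.compile(pattern)  # compile, surfacing regex syntax errors
    except re.error as e:
        raise ValueError(f"Invalid regular expression: {e}")


compile_rule(r"\btmp_\d+\b")  # returns a compiled pattern
# compile_rule(r"tmp_(\d+")   # raises ValueError (unterminated group)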

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py

Lines changed: 3 additions & 1 deletion
@@ -241,7 +241,9 @@ def get_known_query_lineage(

         known_lineage = KnownQueryLineageInfo(
             query_id=get_query_fingerprint(
-                query.query_text, self.identifiers.platform, fast=True
+                query.query_text,
+                self.identifiers.platform,
+                fast=True,
             ),
             query_text=query.query_text,
             downstream=downstream_table_urn,

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_queries.py

Lines changed: 28 additions & 2 deletions
@@ -96,6 +96,17 @@ class SnowflakeQueriesExtractorConfig(ConfigModel):
         "to ignore the temporary staging tables created by known ETL tools.",
     )

+    table_name_normalization_rules: Dict[re.Pattern, str] = pydantic.Field(
+        default={},
+        description="[Advanced] Regex patterns for table names to normalize in lineage ingestion. "
+        "Specify the key as a regex matching the table name as it appears in the query. "
+        "The value is the normalized table name. "
+        "Defaults are set to normalize the staging tables created by known ETL tools. "
+        "The tables identified by these rules should typically be temporary or transient tables "
+        "and should not be used directly in other tools. DataHub will not be able to detect "
+        "cross-platform lineage for such tables.",
+    )
+
     local_temp_path: Optional[pathlib.Path] = pydantic.Field(
         default=None,
         description="Local path to store the audit log.",
@@ -110,6 +121,15 @@ class SnowflakeQueriesExtractorConfig(ConfigModel):
     include_query_usage_statistics: bool = True
     include_operations: bool = True

+    @pydantic.validator("table_name_normalization_rules")
+    def validate_pattern(cls, pattern):
+        if isinstance(pattern, re.Pattern):  # Already compiled, good
+            return pattern
+        try:
+            return re.compile(pattern)  # Attempt compilation
+        except re.error as e:
+            raise ValueError(f"Invalid regular expression: {e}")
+

 class SnowflakeQueriesSourceConfig(
     SnowflakeQueriesExtractorConfig, SnowflakeIdentifierConfig, SnowflakeFilterConfig
@@ -435,7 +455,10 @@ def _parse_audit_log_row(
             default_db=res["default_db"],
             default_schema=res["default_schema"],
             query_hash=get_query_fingerprint(
-                res["query_text"], self.identifiers.platform, fast=True
+                res["query_text"],
+                self.identifiers.platform,
+                fast=True,
+                table_name_normalization_rules=self.config.table_name_normalization_rules,
             ),
         )

@@ -514,7 +537,10 @@ def _parse_audit_log_row(
            # job at eliminating redundant / repetitive queries. As such, we include the fast fingerprint
            # here
             query_id=get_query_fingerprint(
-                res["query_text"], self.identifiers.platform, fast=True
+                res["query_text"],
+                self.identifiers.platform,
+                fast=True,
+                table_name_normalization_rules=self.config.table_name_normalization_rules,
             ),
             query_text=res["query_text"],
             upstreams=upstreams,
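With the rules threaded into both `get_query_fingerprint` call sites above, audit-log rows that differ only in a randomized staging-table name now collapse to a single query_id. A minimal sketch, with an invented Segment-style table name and rule:

import re

from datahub.sql_parsing.sqlglot_utils import get_query_fingerprint

rules = {re.compile(r"\b(segment_staging_)[0-9a-f]{8}\b", re.IGNORECASE): r"\1shuffled"}

fp1 = get_query_fingerprint(
    "INSERT INTO segment_staging_1a2b3c4d SELECT * FROM src",
    platform="snowflake",
    fast=True,
    table_name_normalization_rules=rules,
)
fp2 = get_query_fingerprint(
    "INSERT INTO segment_staging_9f8e7d6c SELECT * FROM src",
    platform="snowflake",
    fast=True,
    table_name_normalization_rules=rules,
)
assert fp1 == fp2  # same query_id across runs; without rules these would differ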

metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py

Lines changed: 1 addition & 0 deletions
@@ -573,6 +573,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
             config=SnowflakeQueriesExtractorConfig(
                 window=self.config,
                 temporary_tables_pattern=self.config.temporary_tables_pattern,
+                table_name_normalization_rules=self.config.table_name_normalization_rules,
                 include_lineage=self.config.include_table_lineage,
                 include_usage_statistics=self.config.include_usage_stats,
                 include_operations=self.config.include_operational_stats,

metadata-ingestion/src/datahub/sql_parsing/sql_parsing_aggregator.py

Lines changed: 3 additions & 0 deletions
@@ -337,6 +337,7 @@ def __init__(
         is_allowed_table: Optional[Callable[[str], bool]] = None,
         format_queries: bool = True,
         query_log: QueryLogSetting = _DEFAULT_QUERY_LOG_SETTING,
+        table_name_normalization_rules: Dict[str, str] = {},
     ) -> None:
         self.platform = DataPlatformUrn(platform)
         self.platform_instance = platform_instance
@@ -495,6 +496,8 @@ def __init__(
         self._tool_meta_extractor = ToolMetaExtractor.create(graph)
         self.report.tool_meta_report = self._tool_meta_extractor.report

+        self.table_name_normalization_rules = table_name_normalization_rules
+
     def close(self) -> None:
         # Compute stats once before closing connections
         self.report.compute_stats()

metadata-ingestion/src/datahub/sql_parsing/sqlglot_lineage.py

Lines changed: 1 addition & 0 deletions
@@ -1085,6 +1085,7 @@ def _sqlglot_lineage_inner(
     query_type, query_type_props = get_query_type_of_sql(
         original_statement, dialect=dialect
     )
+    # TODO: support table name normalization rules for non-fast fingerprinting
     query_fingerprint, debug_info.generalized_statement = get_query_fingerprint_debug(
         original_statement, dialect
     )
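As the TODO notes, the non-fast path does not yet consult the rules: `generalize_query` gains the parameter in sqlglot_utils.py below but leaves it unused. A sketch of the resulting contrast as of this commit, with an invented query and rule:

import re

from datahub.sql_parsing.sqlglot_utils import get_query_fingerprint

rules = {re.compile(r"\b(tmp_load_)\d+\b"): r"\1N"}
q1 = "SELECT id FROM tmp_load_001"
q2 = "SELECT id FROM tmp_load_002"

# Fast path: rules are applied, so the fingerprints match.
assert get_query_fingerprint(
    q1, "snowflake", fast=True, table_name_normalization_rules=rules
) == get_query_fingerprint(
    q2, "snowflake", fast=True, table_name_normalization_rules=rules
)

# Slow (parsed) path: rules are accepted but unused, so the fingerprints differ.
assert get_query_fingerprint(
    q1, "snowflake", table_name_normalization_rules=rules
) != get_query_fingerprint(
    q2, "snowflake", table_name_normalization_rules=rules
)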

metadata-ingestion/src/datahub/sql_parsing/sqlglot_utils.py

Lines changed: 37 additions & 13 deletions
@@ -157,24 +157,26 @@ def _expression_to_string(
         r"\b(ge_tmp_|ge_temp_|gx_temp_)[0-9a-f]{8}\b", re.IGNORECASE
     ): r"\1abcdefgh",
     # Date-suffixed table names (e.g. _20210101)
-    re.compile(r"\b(\w+)(19|20)\d{4}\b"): r"\1YYYYMM",
-    re.compile(r"\b(\w+)(19|20)\d{6}\b"): r"\1YYYYMMDD",
-    re.compile(r"\b(\w+)(19|20)\d{8}\b"): r"\1YYYYMMDDHH",
-    re.compile(r"\b(\w+)(19|20)\d{10}\b"): r"\1YYYYMMDDHHMM",
+    re.compile(r"\b(\w+_?)(19|20)\d{4}\b"): r"\1YYYYMM",
+    re.compile(r"\b(\w+_?)(19|20)\d{6}\b"): r"\1YYYYMMDD",
+    re.compile(r"\b(\w+_?)(19|20)\d{8}\b"): r"\1YYYYMMDDHH",
+    re.compile(r"\b(\w+_?)(19|20)\d{10}\b"): r"\1YYYYMMDDHHMM",
+    re.compile(r"\b(\w+_?)(19|20)\d{12}\b"): r"\1YYYYMMDDHHMMSS",
+    re.compile(r"\b(\w+_?)(19|20)\d{18}\b"): r"\1YYYYMMDDHHMMSSffffff",
 }


 def generalize_query_fast(
     expression: sqlglot.exp.ExpOrStr,
     dialect: DialectOrStr,
-    change_table_names: bool = False,
+    table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
 ) -> str:
     """Variant of `generalize_query` that only does basic normalization.

     Args:
         expression: The SQL query to generalize.
         dialect: The SQL dialect to use.
-        change_table_names: If True, replace table names with placeholders. Note
+        table_name_normalization_rules: If set, replace table names with placeholders. Note
             that this should only be used for query filtering purposes, as it
             violates the general assumption that the queries with the same fingerprint
@@ -189,15 +191,19 @@ def generalize_query_fast(

     REGEX_REPLACEMENTS = {
         **_BASIC_NORMALIZATION_RULES,
-        **(_TABLE_NAME_NORMALIZATION_RULES if change_table_names else {}),
+        **(table_name_normalization_rules if table_name_normalization_rules else {}),
     }

     for pattern, replacement in REGEX_REPLACEMENTS.items():
         query_text = pattern.sub(replacement, query_text)
     return query_text


-def generalize_query(expression: sqlglot.exp.ExpOrStr, dialect: DialectOrStr) -> str:
+def generalize_query(
+    expression: sqlglot.exp.ExpOrStr,
+    dialect: DialectOrStr,
+    table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
+) -> str:
     """
     Generalize/normalize a SQL query.
@@ -222,6 +228,7 @@ def generalize_query(expression: sqlglot.exp.ExpOrStr, dialect: DialectOrStr) ->
     # https://tobikodata.com/are_these_sql_queries_the_same.html
     # which is used to determine if queries are functionally equivalent.

+    # TODO: apply table name normalization rules here
     dialect = get_dialect(dialect)
     expression = sqlglot.maybe_parse(expression, dialect=dialect)
@@ -260,14 +267,23 @@ def generate_hash(text: str) -> str:


 def get_query_fingerprint_debug(
-    expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr, fast: bool = False
+    expression: sqlglot.exp.ExpOrStr,
+    platform: DialectOrStr,
+    table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
+    fast: bool = False,
 ) -> Tuple[str, Optional[str]]:
     try:
         if not fast:
             dialect = get_dialect(platform)
-            expression_sql = generalize_query(expression, dialect=dialect)
+            expression_sql = generalize_query(
+                expression, dialect, table_name_normalization_rules
+            )
         else:
-            expression_sql = generalize_query_fast(expression, dialect=platform)
+            expression_sql = generalize_query_fast(
+                expression,
+                dialect=platform,
+                table_name_normalization_rules=table_name_normalization_rules,
+            )
     except (ValueError, sqlglot.errors.SqlglotError) as e:
         if not isinstance(expression, str):
             raise
@@ -284,7 +300,10 @@ def get_query_fingerprint_debug(


 def get_query_fingerprint(
-    expression: sqlglot.exp.ExpOrStr, platform: DialectOrStr, fast: bool = False
+    expression: sqlglot.exp.ExpOrStr,
+    platform: DialectOrStr,
+    fast: bool = False,
+    table_name_normalization_rules: Optional[Dict[re.Pattern, str]] = None,
 ) -> str:
     """Get a fingerprint for a SQL query.
@@ -306,7 +325,12 @@ def get_query_fingerprint(
         The fingerprint for the SQL query.
     """

-    return get_query_fingerprint_debug(expression, platform, fast=fast)[0]
+    return get_query_fingerprint_debug(
+        expression,
+        platform,
+        fast=fast,
+        table_name_normalization_rules=(table_name_normalization_rules or {}),
+    )[0]


 @functools.lru_cache(maxsize=FORMAT_QUERY_CACHE_SIZE)
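End to end on the fast path, the renamed parameter behaves like the old flag only when rules are actually supplied. A small sketch, with an invented rule and table names:

import re

from datahub.sql_parsing.sqlglot_utils import generalize_query_fast

rules = {re.compile(r"\b(stg_orders_)\d+\b"): r"\1N"}

a = generalize_query_fast(
    "SELECT id FROM stg_orders_0042",
    dialect="snowflake",
    table_name_normalization_rules=rules,
)
b = generalize_query_fast(
    "SELECT id FROM stg_orders_7319",
    dialect="snowflake",
    table_name_normalization_rules=rules,
)
assert a == b  # per-run suffixes are normalized away

Passing no rules (the default `None`) leaves table names untouched, matching the old `change_table_names=False` behavior; the built-in `_TABLE_NAME_NORMALIZATION_RULES` must now be passed explicitly, as the test update below shows.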

metadata-ingestion/tests/unit/sql_parsing/test_sqlglot_utils.py

Lines changed: 6 additions & 1 deletion
@@ -8,6 +8,7 @@
 from datahub.sql_parsing.sql_parsing_common import QueryType
 from datahub.sql_parsing.sqlglot_lineage import _UPDATE_ARGS_NOT_SUPPORTED_BY_SELECT
 from datahub.sql_parsing.sqlglot_utils import (
+    _TABLE_NAME_NORMALIZATION_RULES,
     generalize_query,
     generalize_query_fast,
     get_dialect,
@@ -173,7 +174,11 @@ def test_query_generalization(
         assert generalize_query(query, dialect=dialect) == expected
     if mode in {QueryGeneralizationTestMode.FAST, QueryGeneralizationTestMode.BOTH}:
         assert (
-            generalize_query_fast(query, dialect=dialect, change_table_names=True)
+            generalize_query_fast(
+                query,
+                dialect=dialect,
+                table_name_normalization_rules=_TABLE_NAME_NORMALIZATION_RULES,
+            )
             == expected
         )
