feat(ingestion/snowflake): adds streams as a new dataset with lineage and properties. (#12318)

Co-authored-by: Mayuri Nehate <[email protected]>
brock-acryl and mayurinehate authored Feb 5, 2025
1 parent 7f6e399 commit ac13f25
Showing 17 changed files with 3,783 additions and 1,811 deletions.
9 changes: 7 additions & 2 deletions metadata-ingestion/docs/sources/snowflake/snowflake_pre.md
@@ -15,6 +15,8 @@ grant operate, usage on warehouse "<your-warehouse>" to role datahub_role;
grant usage on DATABASE "<your-database>" to role datahub_role;
grant usage on all schemas in database "<your-database>" to role datahub_role;
grant usage on future schemas in database "<your-database>" to role datahub_role;
grant select on all streams in database "<your-database>" to role datahub_role;
grant select on future streams in database "<your-database>" to role datahub_role;

// If you are NOT using Snowflake Profiling or Classification feature: Grant references privileges to your tables and views
grant references on all tables in database "<your-database>" to role datahub_role;
@@ -50,9 +52,12 @@ The details of each granted privilege can be viewed in [snowflake docs](https://
If the warehouse is already running during ingestion or has auto-resume enabled,
this permission is not required.
- `usage` is required for us to run queries using the warehouse
- `usage` on `database` and `schema` are required because without it tables and views inside them are not accessible. If an admin does the required grants on `table` but misses the grants on `schema` or the `database` in which the table/view exists then we will not be able to get metadata for the table/view.
- `usage` on `database` and `schema` are required because without it tables, views, and streams inside them are not accessible. If an admin does the required grants on `table` but misses the grants on `schema` or the `database` in which the table/view/stream exists then we will not be able to get metadata for the table/view/stream.
- If metadata is required only on some schemas, then you can grant the usage privileges only on a particular schema, like:

```sql
grant usage on schema "<your-database>"."<your-schema>" to role datahub_role;
```
- `select` on `streams` is required in order for stream definitions to be available. This does not allow selecting the data itself (which is not required) unless the underlying dataset also has select access; a schema-scoped variant of these grants is sketched below.
```sql
grant usage on schema "<your-database>"."<your-schema>" to role datahub_role;
```
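If you want to limit the stream grants to a single schema instead of the whole database, a minimal sketch of the equivalent statements (using the same placeholder names and `datahub_role` as above) would be:

```sql
grant select on all streams in schema "<your-database>"."<your-schema>" to role datahub_role;
grant select on future streams in schema "<your-database>"."<your-schema>" to role datahub_role;
```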
@@ -24,6 +24,7 @@ class DatasetSubTypes(StrEnum):
SAC_LIVE_DATA_MODEL = "Live Data Model"
NEO4J_NODE = "Neo4j Node"
NEO4J_RELATIONSHIP = "Neo4j Relationship"
SNOWFLAKE_STREAM = "Snowflake Stream"

# TODO: Create separate entity...
NOTEBOOK = "Notebook"
@@ -53,6 +53,7 @@ class SnowflakeObjectDomain(StrEnum):
SCHEMA = "schema"
COLUMN = "column"
ICEBERG_TABLE = "iceberg table"
STREAM = "stream"


GENERIC_PERMISSION_ERROR_KEY = "permission-error"
@@ -98,6 +98,11 @@ class SnowflakeFilterConfig(SQLFilterConfig):
)
# table_pattern and view_pattern are inherited from SQLFilterConfig

stream_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern.allow_all(),
description="Regex patterns for streams to filter in ingestion. Note: Defaults to table_pattern if not specified. Specify regex to match the entire view name in database.schema.view format. e.g. to match all views starting with customer in Customer database and public schema, use the regex 'Customer.public.customer.*'",
)

match_fully_qualified_names: bool = Field(
default=False,
description="Whether `schema_pattern` is matched against fully qualified schema name `<catalog>.<schema>`.",
@@ -274,6 +279,11 @@ class SnowflakeV2Config(
description="List of regex patterns for tags to include in ingestion. Only used if `extract_tags` is enabled.",
)

include_streams: bool = Field(
default=True,
description="If enabled, streams will be ingested as separate entities from tables/views.",
)

structured_property_pattern: AllowDenyPattern = Field(
default=AllowDenyPattern.allow_all(),
description=(
@@ -49,6 +49,7 @@
from datahub.sql_parsing.schema_resolver import SchemaResolver
from datahub.sql_parsing.sql_parsing_aggregator import (
KnownLineageMapping,
ObservedQuery,
PreparsedQuery,
SqlAggregatorReport,
SqlParsingAggregator,
@@ -241,7 +242,13 @@ def get_workunits_internal(
use_cached_audit_log = audit_log_file.exists()

queries: FileBackedList[
Union[KnownLineageMapping, PreparsedQuery, TableRename, TableSwap]
Union[
KnownLineageMapping,
PreparsedQuery,
TableRename,
TableSwap,
ObservedQuery,
]
]
if use_cached_audit_log:
logger.info("Using cached audit log")
@@ -252,7 +259,13 @@

shared_connection = ConnectionWrapper(audit_log_file)
queries = FileBackedList(shared_connection)
entry: Union[KnownLineageMapping, PreparsedQuery, TableRename, TableSwap]
entry: Union[
KnownLineageMapping,
PreparsedQuery,
TableRename,
TableSwap,
ObservedQuery,
]

with self.report.copy_history_fetch_timer:
for entry in self.fetch_copy_history():
@@ -329,7 +342,7 @@ def fetch_copy_history(self) -> Iterable[KnownLineageMapping]:

def fetch_query_log(
self, users: UsersMapping
) -> Iterable[Union[PreparsedQuery, TableRename, TableSwap]]:
) -> Iterable[Union[PreparsedQuery, TableRename, TableSwap, ObservedQuery]]:
query_log_query = _build_enriched_query_log_query(
start_time=self.config.window.start_time,
end_time=self.config.window.end_time,
@@ -362,7 +375,7 @@ def fetch_query_log(

def _parse_audit_log_row(
self, row: Dict[str, Any], users: UsersMapping
) -> Optional[Union[TableRename, TableSwap, PreparsedQuery]]:
) -> Optional[Union[TableRename, TableSwap, PreparsedQuery, ObservedQuery]]:
json_fields = {
"DIRECT_OBJECTS_ACCESSED",
"OBJECTS_MODIFIED",
@@ -398,6 +411,34 @@ def _parse_audit_log_row(
pass
else:
return None

user = CorpUserUrn(
self.identifiers.get_user_identifier(
res["user_name"], users.get(res["user_name"])
)
)

# Use direct_objects_accessed instead of objects_modified:
# objects_modified returns $SYS_VIEW_X with no mapping
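# Illustrative (hypothetical) example: a statement such as
#   INSERT INTO orders_history SELECT * FROM orders_stream
# surfaces the stream with objectDomain == "Stream" in direct_objects_accessed.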
has_stream_objects = any(
obj.get("objectDomain") == "Stream" for obj in direct_objects_accessed
)

# If a stream is used, default to query parsing.
if has_stream_objects:
logger.debug("Found matching stream object")
return ObservedQuery(
query=res["query_text"],
session_id=res["session_id"],
timestamp=res["query_start_time"].astimezone(timezone.utc),
user=user,
default_db=res["default_db"],
default_schema=res["default_schema"],
query_hash=get_query_fingerprint(
res["query_text"], self.identifiers.platform, fast=True
),
)

upstreams = []
column_usage = {}

@@ -460,12 +501,6 @@ def _parse_audit_log_row(
)
)

user = CorpUserUrn(
self.identifiers.get_user_identifier(
res["user_name"], users.get(res["user_name"])
)
)

timestamp: datetime = res["query_start_time"]
timestamp = timestamp.astimezone(timezone.utc)

@@ -9,6 +9,7 @@
from datahub.utilities.prefix_batch_builder import PrefixGroup

SHOW_VIEWS_MAX_PAGE_SIZE = 10000
SHOW_STREAM_MAX_PAGE_SIZE = 10000


def create_deny_regex_sql_filter(
@@ -36,6 +37,7 @@ class SnowflakeQuery:
SnowflakeObjectDomain.VIEW.capitalize(),
SnowflakeObjectDomain.MATERIALIZED_VIEW.capitalize(),
SnowflakeObjectDomain.ICEBERG_TABLE.capitalize(),
SnowflakeObjectDomain.STREAM.capitalize(),
}

ACCESS_HISTORY_TABLE_VIEW_DOMAINS_FILTER = "({})".format(
@@ -44,7 +46,8 @@
ACCESS_HISTORY_TABLE_DOMAINS_FILTER = (
"("
f"'{SnowflakeObjectDomain.TABLE.capitalize()}',"
f"'{SnowflakeObjectDomain.VIEW.capitalize()}'"
f"'{SnowflakeObjectDomain.VIEW.capitalize()}',"
f"'{SnowflakeObjectDomain.STREAM.capitalize()}',"
")"
)

@@ -963,3 +966,19 @@ def dmf_assertion_results(start_time_millis: int, end_time_millis: int) -> str:
@staticmethod
def get_all_users() -> str:
return """SELECT name as "NAME", email as "EMAIL" FROM SNOWFLAKE.ACCOUNT_USAGE.USERS"""

@staticmethod
def streams_for_database(
db_name: str,
limit: int = SHOW_STREAM_MAX_PAGE_SIZE,
stream_pagination_marker: Optional[str] = None,
) -> str:
# SHOW STREAMS can return a maximum of 10000 rows.
# https://docs.snowflake.com/en/sql-reference/sql/show-streams#usage-notes
assert limit <= SHOW_STREAM_MAX_PAGE_SIZE

# To work around this, we paginate through the results using the FROM clause.
from_clause = (
f"""FROM '{stream_pagination_marker}'""" if stream_pagination_marker else ""
)
return f"""SHOW STREAMS IN DATABASE {db_name} LIMIT {limit} {from_clause};"""
@@ -104,6 +104,7 @@ class SnowflakeV2Report(
schemas_scanned: int = 0
databases_scanned: int = 0
tags_scanned: int = 0
streams_scanned: int = 0

include_usage_stats: bool = False
include_operational_stats: bool = False
@@ -113,6 +114,7 @@
table_lineage_query_secs: float = -1
external_lineage_queries_secs: float = -1
num_tables_with_known_upstreams: int = 0
num_streams_with_known_upstreams: int = 0
num_upstream_lineage_edge_parsing_failed: int = 0
num_secure_views_missing_definition: int = 0
num_structured_property_templates_created: int = 0
@@ -131,6 +133,8 @@
num_get_tags_for_object_queries: int = 0
num_get_tags_on_columns_for_table_queries: int = 0

num_get_streams_for_schema_queries: int = 0

rows_zero_objects_modified: int = 0

_processed_tags: MutableSet[str] = field(default_factory=set)
@@ -157,6 +161,8 @@ def report_entity_scanned(self, name: str, ent_type: str = "table") -> None:
return
self._scanned_tags.add(name)
self.tags_scanned += 1
elif ent_type == "stream":
self.streams_scanned += 1
else:
raise KeyError(f"Unknown entity {ent_type}.")

@@ -14,7 +14,7 @@
)
from datahub.ingestion.source.sql.sql_generic import BaseColumn, BaseTable, BaseView
from datahub.utilities.file_backed_collections import FileBackedDict
from datahub.utilities.prefix_batch_builder import build_prefix_batches
from datahub.utilities.prefix_batch_builder import PrefixGroup, build_prefix_batches
from datahub.utilities.serialized_lru_cache import serialized_lru_cache

logger: logging.Logger = logging.getLogger(__name__)
@@ -118,6 +118,7 @@ class SnowflakeSchema:
comment: Optional[str]
tables: List[str] = field(default_factory=list)
views: List[str] = field(default_factory=list)
streams: List[str] = field(default_factory=list)
tags: Optional[List[SnowflakeTag]] = None


@@ -131,6 +132,29 @@ class SnowflakeDatabase:
tags: Optional[List[SnowflakeTag]] = None


@dataclass
class SnowflakeStream:
name: str
created: datetime
owner: str
source_type: str
type: str
stale: str
mode: str
invalid_reason: str
owner_role_type: str
database_name: str
schema_name: str
table_name: str
comment: Optional[str]
columns: List[SnowflakeColumn] = field(default_factory=list)
stale_after: Optional[datetime] = None
base_tables: Optional[str] = None
tags: Optional[List[SnowflakeTag]] = None
column_tags: Dict[str, List[SnowflakeTag]] = field(default_factory=dict)
last_altered: Optional[datetime] = None


class _SnowflakeTagCache:
def __init__(self) -> None:
# self._database_tags[<database_name>] = list of tags applied to database
@@ -208,6 +232,7 @@ def as_obj(self) -> Dict[str, Dict[str, int]]:
self.get_tables_for_database,
self.get_views_for_database,
self.get_columns_for_schema,
self.get_streams_for_database,
self.get_pk_constraints_for_schema,
self.get_fk_constraints_for_schema,
]
@@ -431,9 +456,18 @@ def get_columns_for_schema(
# For massive schemas, use a FileBackedDict to avoid memory issues.
columns = FileBackedDict()

object_batches = build_prefix_batches(
all_objects, max_batch_size=10000, max_groups_in_batch=5
)
# Single-object case (used for streams): look up columns for exactly that object
if len(all_objects) == 1:
object_batches = [
[PrefixGroup(prefix=all_objects[0], names=[], exact_match=True)]
]
else:
# Build batches for full schema scan
object_batches = build_prefix_batches(
all_objects, max_batch_size=10000, max_groups_in_batch=5
)

# Process batches
for batch_index, object_batch in enumerate(object_batches):
if batch_index > 0:
logger.info(
@@ -611,3 +645,63 @@ def get_tags_on_columns_for_table(
tags[column_name].append(snowflake_tag)

return tags

@serialized_lru_cache(maxsize=1)
def get_streams_for_database(
self, db_name: str
) -> Dict[str, List[SnowflakeStream]]:
page_limit = SHOW_VIEWS_MAX_PAGE_SIZE

streams: Dict[str, List[SnowflakeStream]] = {}

first_iteration = True
stream_pagination_marker: Optional[str] = None
while first_iteration or stream_pagination_marker is not None:
cur = self.connection.query(
SnowflakeQuery.streams_for_database(
db_name,
limit=page_limit,
stream_pagination_marker=stream_pagination_marker,
)
)

first_iteration = False
stream_pagination_marker = None

result_set_size = 0
for stream in cur:
result_set_size += 1

stream_name = stream["name"]
schema_name = stream["schema_name"]
if schema_name not in streams:
streams[schema_name] = []
streams[stream["schema_name"]].append(
SnowflakeStream(
name=stream["name"],
created=stream["created_on"],
owner=stream["owner"],
comment=stream["comment"],
source_type=stream["source_type"],
type=stream["type"],
stale=stream["stale"],
mode=stream["mode"],
database_name=stream["database_name"],
schema_name=stream["schema_name"],
invalid_reason=stream["invalid_reason"],
owner_role_type=stream["owner_role_type"],
stale_after=stream["stale_after"],
table_name=stream["table_name"],
base_tables=stream["base_tables"],
last_altered=stream["created_on"],
)
)

if result_set_size >= page_limit:
# If we hit the limit, we need to send another request to get the next page.
logger.info(
f"Fetching next page of streams for {db_name} - after {stream_name}"
)
stream_pagination_marker = stream_name

return streams