
Commit 72aab9f

yoonhyejin and hsheth2 authored
feat(sdk): add sdk lineage client (#13244)
Co-authored-by: Harshal Sheth <[email protected]>
1 parent b75dbaa commit 72aab9f

16 files changed: +965 -51 lines

metadata-ingestion/src/datahub/sdk/_shared.py

Lines changed: 3 additions & 5 deletions

@@ -41,7 +41,7 @@
     Urn,
     VersionSetUrn,
 )
-from datahub.sdk._utils import add_list_unique, remove_list_unique
+from datahub.sdk._utils import DEFAULT_ACTOR_URN, add_list_unique, remove_list_unique
 from datahub.sdk.entity import Entity
 from datahub.utilities.urns.error import InvalidUrnError

@@ -54,8 +54,6 @@

 ActorUrn: TypeAlias = Union[CorpUserUrn, CorpGroupUrn]

-_DEFAULT_ACTOR_URN = CorpUserUrn("__ingestion").urn()
-
 TrainingMetricsInputType: TypeAlias = Union[
     List[models.MLMetricClass], Dict[str, Optional[str]]
 ]

@@ -475,7 +473,7 @@ def _parse_glossary_term_association_class(
     def _terms_audit_stamp(self) -> models.AuditStampClass:
         return models.AuditStampClass(
             time=0,
-            actor=_DEFAULT_ACTOR_URN,
+            actor=DEFAULT_ACTOR_URN,
         )

     def set_terms(self, terms: TermsInputType) -> None:

@@ -563,7 +561,7 @@ def links(self) -> Optional[List[models.InstitutionalMemoryMetadataClass]]:
     def _institutional_memory_audit_stamp(self) -> models.AuditStampClass:
         return models.AuditStampClass(
             time=0,
-            actor=_DEFAULT_ACTOR_URN,
+            actor=DEFAULT_ACTOR_URN,
         )

     @classmethod

metadata-ingestion/src/datahub/sdk/_utils.py

Lines changed: 4 additions & 0 deletions

@@ -1,6 +1,10 @@
 from typing import Any, Callable, List, Protocol, TypeVar

 from datahub.errors import ItemNotFoundError
+from datahub.metadata.urns import CorpUserUrn
+
+# TODO: Change __ingestion to _ingestion.
+DEFAULT_ACTOR_URN = CorpUserUrn("__ingestion").urn()


 class _SupportsEq(Protocol):
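For reference, the now-public constant resolves to the standard ingestion actor URN (a quick check, assuming DataHub's usual urn:li:corpuser:<id> format):

>>> from datahub.sdk._utils import DEFAULT_ACTOR_URN
>>> DEFAULT_ACTOR_URN
'urn:li:corpuser:__ingestion'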

metadata-ingestion/src/datahub/sdk/dataset.py

Lines changed: 2 additions & 2 deletions

@@ -87,7 +87,7 @@ def _parse_upstream_input(
     assert_never(upstream_input)


-def _parse_cll_mapping(
+def parse_cll_mapping(
     *,
     upstream: DatasetUrnOrStr,
     downstream: DatasetUrnOrStr,

@@ -142,7 +142,7 @@ def _parse_upstream_lineage_input(
         )
     )
     cll.extend(
-        _parse_cll_mapping(
+        parse_cll_mapping(
             upstream=dataset_urn,
             downstream=downstream_urn,
             cll_mapping=column_lineage,
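Dropping the leading underscore makes parse_cll_mapping part of the module's public surface so the new LineageClient can reuse it. As the matching helpers below construct it, a ColumnLineageMapping maps each downstream column to the upstream columns it derives from; a minimal sketch (the dataset URNs and column names are hypothetical):

from datahub.sdk.dataset import parse_cll_mapping

cll = parse_cll_mapping(
    upstream="urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_raw,PROD)",
    downstream="urn:li:dataset:(urn:li:dataPlatform:snowflake,sales_clean,PROD)",
    # downstream column -> upstream column(s) it is derived from
    cll_mapping={
        "customer_id": ["customer_id"],
        "full_name": ["first_name", "last_name"],
    },
)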
metadata-ingestion/src/datahub/sdk/lineage_client.py (new file)

Lines changed: 235 additions & 0 deletions

from __future__ import annotations

import difflib
import logging
from typing import TYPE_CHECKING, List, Literal, Optional, Set, Union

import datahub.metadata.schema_classes as models
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.errors import SdkUsageError
from datahub.metadata.schema_classes import SchemaMetadataClass
from datahub.metadata.urns import DatasetUrn, QueryUrn
from datahub.sdk._shared import DatasetUrnOrStr
from datahub.sdk._utils import DEFAULT_ACTOR_URN
from datahub.sdk.dataset import ColumnLineageMapping, parse_cll_mapping
from datahub.specific.dataset import DatasetPatchBuilder
from datahub.sql_parsing.fingerprint_utils import generate_hash
from datahub.utilities.ordered_set import OrderedSet

if TYPE_CHECKING:
    from datahub.sdk.main_client import DataHubClient

logger = logging.getLogger(__name__)

_empty_audit_stamp = models.AuditStampClass(
    time=0,
    actor=DEFAULT_ACTOR_URN,
)


class LineageClient:
    def __init__(self, client: DataHubClient):
        self._client = client

    def _get_fields_from_dataset_urn(self, dataset_urn: DatasetUrn) -> Set[str]:
        schema_metadata = self._client._graph.get_aspect(
            str(dataset_urn), SchemaMetadataClass
        )
        if schema_metadata is None:
            return set()

        return {field.fieldPath for field in schema_metadata.fields}

    @classmethod
    def _get_strict_column_lineage(
        cls,
        upstream_fields: Set[str],
        downstream_fields: Set[str],
    ) -> ColumnLineageMapping:
        """Find matches between upstream and downstream fields with case-insensitive matching."""
        strict_column_lineage: ColumnLineageMapping = {}

        # Create case-insensitive mapping of upstream fields
        case_insensitive_map = {field.lower(): field for field in upstream_fields}

        # Match downstream fields using case-insensitive comparison
        for downstream_field in downstream_fields:
            lower_field = downstream_field.lower()
            if lower_field in case_insensitive_map:
                # Use the original case of the upstream field
                strict_column_lineage[downstream_field] = [
                    case_insensitive_map[lower_field]
                ]

        return strict_column_lineage

    @classmethod
    def _get_fuzzy_column_lineage(
        cls,
        upstream_fields: Set[str],
        downstream_fields: Set[str],
    ) -> ColumnLineageMapping:
        """Generate fuzzy matches between upstream and downstream fields."""

        # Simple normalization function for better matching
        def normalize(s: str) -> str:
            return s.lower().replace("_", "")

        # Create normalized lookup for upstream fields
        normalized_upstream = {normalize(field): field for field in upstream_fields}

        fuzzy_column_lineage = {}
        for downstream_field in downstream_fields:
            # Try exact match first
            if downstream_field in upstream_fields:
                fuzzy_column_lineage[downstream_field] = [downstream_field]
                continue

            # Try normalized match
            norm_downstream = normalize(downstream_field)
            if norm_downstream in normalized_upstream:
                fuzzy_column_lineage[downstream_field] = [
                    normalized_upstream[norm_downstream]
                ]
                continue

            # If no direct match, find closest match using similarity
            matches = difflib.get_close_matches(
                norm_downstream,
                normalized_upstream.keys(),
                n=1,  # Return only the best match
                cutoff=0.8,  # Adjust cutoff for sensitivity
            )

            if matches:
                fuzzy_column_lineage[downstream_field] = [
                    normalized_upstream[matches[0]]
                ]

        return fuzzy_column_lineage
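To make the fuzzy matching concrete: normalization lowercases and strips underscores, an exact or normalized hit wins outright, and anything else must clear the 0.8 difflib cutoff. A quick illustration (the field names are made up; key order of the result can vary since the inputs are sets):

LineageClient._get_fuzzy_column_lineage(
    upstream_fields={"CustomerID", "customer_name", "amount"},
    downstream_fields={"customer_id", "CUSTOMER_NAME", "amt"},
)
# -> {"customer_id": ["CustomerID"], "CUSTOMER_NAME": ["customer_name"]}
# "amt" gets no mapping: its closest normalized candidate ("amount")
# scores roughly 0.67, below the 0.8 cutoff.

The listing continues with the public entry points.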
110+
111+
def add_dataset_copy_lineage(
112+
self,
113+
*,
114+
upstream: DatasetUrnOrStr,
115+
downstream: DatasetUrnOrStr,
116+
column_lineage: Union[
117+
None, ColumnLineageMapping, Literal["auto_fuzzy", "auto_strict"]
118+
] = "auto_fuzzy",
119+
) -> None:
120+
upstream = DatasetUrn.from_string(upstream)
121+
downstream = DatasetUrn.from_string(downstream)
122+
123+
if column_lineage is None:
124+
cll = None
125+
elif column_lineage in ["auto_fuzzy", "auto_strict"]:
126+
upstream_schema = self._get_fields_from_dataset_urn(upstream)
127+
downstream_schema = self._get_fields_from_dataset_urn(downstream)
128+
if column_lineage == "auto_fuzzy":
129+
mapping = self._get_fuzzy_column_lineage(
130+
upstream_schema, downstream_schema
131+
)
132+
else:
133+
mapping = self._get_strict_column_lineage(
134+
upstream_schema, downstream_schema
135+
)
136+
cll = parse_cll_mapping(
137+
upstream=upstream,
138+
downstream=downstream,
139+
cll_mapping=mapping,
140+
)
141+
elif isinstance(column_lineage, dict):
142+
cll = parse_cll_mapping(
143+
upstream=upstream,
144+
downstream=downstream,
145+
cll_mapping=column_lineage,
146+
)
147+
148+
updater = DatasetPatchBuilder(str(downstream))
149+
updater.add_upstream_lineage(
150+
models.UpstreamClass(
151+
dataset=str(upstream),
152+
type=models.DatasetLineageTypeClass.COPY,
153+
)
154+
)
155+
for cl in cll or []:
156+
updater.add_fine_grained_upstream_lineage(cl)
157+
158+
self._client.entities.update(updater)
159+
160+
def add_dataset_transform_lineage(
161+
self,
162+
*,
163+
upstream: DatasetUrnOrStr,
164+
downstream: DatasetUrnOrStr,
165+
column_lineage: Optional[ColumnLineageMapping] = None,
166+
query_text: Optional[str] = None,
167+
) -> None:
168+
upstream = DatasetUrn.from_string(upstream)
169+
downstream = DatasetUrn.from_string(downstream)
170+
171+
cll = None
172+
if column_lineage is not None:
173+
cll = parse_cll_mapping(
174+
upstream=upstream,
175+
downstream=downstream,
176+
cll_mapping=column_lineage,
177+
)
178+
179+
fields_involved = OrderedSet([str(upstream), str(downstream)])
180+
if cll is not None:
181+
for c in cll:
182+
for field in c.upstreams or []:
183+
fields_involved.add(field)
184+
for field in c.downstreams or []:
185+
fields_involved.add(field)
186+
187+
query_urn = None
188+
query_entity = None
189+
if query_text:
190+
# Eventually we might want to use our regex-based fingerprinting instead.
191+
fingerprint = generate_hash(query_text)
192+
query_urn = QueryUrn(fingerprint).urn()
193+
194+
from datahub.sql_parsing.sql_parsing_aggregator import make_query_subjects
195+
196+
query_entity = MetadataChangeProposalWrapper.construct_many(
197+
query_urn,
198+
aspects=[
199+
models.QueryPropertiesClass(
200+
statement=models.QueryStatementClass(
201+
value=query_text, language=models.QueryLanguageClass.SQL
202+
),
203+
source=models.QuerySourceClass.SYSTEM,
204+
created=_empty_audit_stamp,
205+
lastModified=_empty_audit_stamp,
206+
),
207+
make_query_subjects(list(fields_involved)),
208+
],
209+
)
210+
211+
updater = DatasetPatchBuilder(str(downstream))
212+
updater.add_upstream_lineage(
213+
models.UpstreamClass(
214+
dataset=str(upstream),
215+
type=models.DatasetLineageTypeClass.TRANSFORMED,
216+
query=query_urn,
217+
)
218+
)
219+
for cl in cll or []:
220+
cl.query = query_urn
221+
updater.add_fine_grained_upstream_lineage(cl)
222+
223+
# Throw if the dataset does not exist.
224+
# We need to manually call .build() instead of reusing client.update()
225+
# so that we make just one emit_mcps call.
226+
if not self._client._graph.exists(updater.urn):
227+
raise SdkUsageError(
228+
f"Dataset {updater.urn} does not exist, and hence cannot be updated."
229+
)
230+
mcps: List[
231+
Union[MetadataChangeProposalWrapper, models.MetadataChangeProposalClass]
232+
] = list(updater.build())
233+
if query_entity:
234+
mcps.extend(query_entity)
235+
self._client._graph.emit_mcps(mcps)
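Taken together, the two entry points cover the copy and transform lineage cases. A usage sketch against a live instance (the URNs and SQL are hypothetical, and client is assumed to be a connected DataHubClient):

# Copy lineage: the column mapping is inferred from the two schemas
# ("auto_fuzzy" is the default; pass None to skip column-level lineage).
client.lineage.add_dataset_copy_lineage(
    upstream="urn:li:dataset:(urn:li:dataPlatform:s3,raw.orders,PROD)",
    downstream="urn:li:dataset:(urn:li:dataPlatform:snowflake,analytics.orders,PROD)",
)

# Transform lineage: an explicit column mapping, plus the SQL that produced it.
# The query text is fingerprinted into a Query entity attached to the edge.
client.lineage.add_dataset_transform_lineage(
    upstream="urn:li:dataset:(urn:li:dataPlatform:snowflake,analytics.orders,PROD)",
    downstream="urn:li:dataset:(urn:li:dataPlatform:snowflake,analytics.orders_daily,PROD)",
    column_lineage={"order_date": ["created_at"]},
    query_text="INSERT INTO orders_daily SELECT DATE(created_at) AS order_date FROM orders",
)

Note that add_dataset_transform_lineage raises SdkUsageError if the downstream dataset does not already exist, while add_dataset_copy_lineage goes through client.entities.update.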

metadata-ingestion/src/datahub/sdk/main_client.py

Lines changed: 4 additions & 1 deletion

@@ -6,6 +6,7 @@
 from datahub.ingestion.graph.client import DataHubGraph, get_default_graph
 from datahub.ingestion.graph.config import DatahubClientConfig
 from datahub.sdk.entity_client import EntityClient
+from datahub.sdk.lineage_client import LineageClient
 from datahub.sdk.resolver_client import ResolverClient
 from datahub.sdk.search_client import SearchClient

@@ -99,4 +100,6 @@ def resolve(self) -> ResolverClient:
     def search(self) -> SearchClient:
         return SearchClient(self)

-    # TODO: lineage client
+    @property
+    def lineage(self) -> LineageClient:
+        return LineageClient(self)
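This replaces the TODO with a property that mirrors how the search and resolver clients are exposed, so lineage operations need no separate setup; a sketch (the constructor arguments are assumptions; see DatahubClientConfig for the real options):

from datahub.sdk.main_client import DataHubClient

client = DataHubClient(server="http://localhost:8080", token="<token>")
lineage = client.lineage  # a LineageClient bound to the same underlying graph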
metadata-ingestion/src/datahub/sql_parsing/fingerprint_utils.py (new file)

Lines changed: 6 additions & 0 deletions

import hashlib


def generate_hash(text: str) -> str:
    # Once we move to Python 3.9+, we can set `usedforsecurity=False`.
    return hashlib.sha256(text.encode("utf-8")).hexdigest()
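The lineage client feeds this fingerprint straight into a QueryUrn, so identical query text always maps to the same query entity; a small sketch:

from datahub.metadata.urns import QueryUrn
from datahub.sql_parsing.fingerprint_utils import generate_hash

# Deterministic: re-emitting lineage for identical SQL reuses one Query entity.
query_urn = QueryUrn(generate_hash("SELECT 1")).urn()
# -> "urn:li:query:" followed by the SHA-256 hex digest of the text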
