|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import difflib |
| 4 | +import logging |
| 5 | +from typing import TYPE_CHECKING, List, Literal, Optional, Set, Union |
| 6 | + |
| 7 | +import datahub.metadata.schema_classes as models |
| 8 | +from datahub.emitter.mcp import MetadataChangeProposalWrapper |
| 9 | +from datahub.errors import SdkUsageError |
| 10 | +from datahub.metadata.schema_classes import SchemaMetadataClass |
| 11 | +from datahub.metadata.urns import DatasetUrn, QueryUrn |
| 12 | +from datahub.sdk._shared import DatasetUrnOrStr |
| 13 | +from datahub.sdk._utils import DEFAULT_ACTOR_URN |
| 14 | +from datahub.sdk.dataset import ColumnLineageMapping, parse_cll_mapping |
| 15 | +from datahub.specific.dataset import DatasetPatchBuilder |
| 16 | +from datahub.sql_parsing.fingerprint_utils import generate_hash |
| 17 | +from datahub.utilities.ordered_set import OrderedSet |
| 18 | + |
if TYPE_CHECKING:
    # Imported only for type annotations to avoid a circular import at runtime.
    from datahub.sdk.main_client import DataHubClient

logger = logging.getLogger(__name__)

# Audit stamp with epoch time 0, used where an audit stamp is required but no
# meaningful timestamp is available (e.g. system-generated query entities).
_empty_audit_stamp = models.AuditStampClass(
    time=0,
    actor=DEFAULT_ACTOR_URN,
)
| 28 | + |
| 29 | + |
class LineageClient:
    """Client for creating dataset-to-dataset lineage in DataHub.

    Supports copy lineage (dataset duplicated as-is, with optional automatic
    column matching) and transform lineage (dataset derived via a query, with
    optional explicit column mapping and query text).
    """

    def __init__(self, client: DataHubClient):
        self._client = client

    def _get_fields_from_dataset_urn(self, dataset_urn: DatasetUrn) -> Set[str]:
        """Return the schema field paths of *dataset_urn*.

        Returns an empty set when the dataset has no schemaMetadata aspect.
        """
        schema_metadata = self._client._graph.get_aspect(
            str(dataset_urn), SchemaMetadataClass
        )
        if schema_metadata is None:
            # Fix: was `return Set()`, which instantiates the deprecated
            # typing.Set alias; use the builtin set instead.
            return set()

        return {field.fieldPath for field in schema_metadata.fields}

    @classmethod
    def _get_strict_column_lineage(
        cls,
        upstream_fields: Set[str],
        downstream_fields: Set[str],
    ) -> ColumnLineageMapping:
        """Map each downstream field to an upstream field whose name matches
        case-insensitively; downstream fields without a match are omitted.
        """
        strict_column_lineage: ColumnLineageMapping = {}

        # Case-insensitive lookup: lowercased name -> original upstream name.
        case_insensitive_map = {field.lower(): field for field in upstream_fields}

        for downstream_field in downstream_fields:
            lower_field = downstream_field.lower()
            if lower_field in case_insensitive_map:
                # Preserve the original casing of the upstream field.
                strict_column_lineage[downstream_field] = [
                    case_insensitive_map[lower_field]
                ]

        return strict_column_lineage

    @classmethod
    def _get_fuzzy_column_lineage(
        cls,
        upstream_fields: Set[str],
        downstream_fields: Set[str],
    ) -> ColumnLineageMapping:
        """Map downstream fields to upstream fields using progressively looser
        matching: exact, then case/underscore-insensitive, then difflib
        similarity. Downstream fields with no match are omitted.
        """

        # Normalization for loose matching: ignore case and underscores.
        def normalize(s: str) -> str:
            return s.lower().replace("_", "")

        # Normalized name -> original upstream name.
        normalized_upstream = {normalize(field): field for field in upstream_fields}

        fuzzy_column_lineage: ColumnLineageMapping = {}
        for downstream_field in downstream_fields:
            # 1. Exact match.
            if downstream_field in upstream_fields:
                fuzzy_column_lineage[downstream_field] = [downstream_field]
                continue

            # 2. Normalized (case/underscore-insensitive) match.
            norm_downstream = normalize(downstream_field)
            if norm_downstream in normalized_upstream:
                fuzzy_column_lineage[downstream_field] = [
                    normalized_upstream[norm_downstream]
                ]
                continue

            # 3. Closest fuzzy match above the similarity cutoff.
            matches = difflib.get_close_matches(
                norm_downstream,
                normalized_upstream.keys(),
                n=1,  # Return only the best match
                cutoff=0.8,  # Adjust cutoff for sensitivity
            )

            if matches:
                fuzzy_column_lineage[downstream_field] = [
                    normalized_upstream[matches[0]]
                ]

        return fuzzy_column_lineage

    def add_dataset_copy_lineage(
        self,
        *,
        upstream: DatasetUrnOrStr,
        downstream: DatasetUrnOrStr,
        column_lineage: Union[
            None, ColumnLineageMapping, Literal["auto_fuzzy", "auto_strict"]
        ] = "auto_fuzzy",
    ) -> None:
        """Record that *downstream* is a copy of *upstream*.

        Args:
            upstream: Source dataset urn.
            downstream: Copied dataset urn.
            column_lineage: Either an explicit downstream->upstreams mapping,
                "auto_fuzzy"/"auto_strict" to infer the mapping from both
                datasets' schemas, or None to skip column-level lineage.

        Raises:
            SdkUsageError: If *column_lineage* is not one of the accepted forms.
        """
        upstream = DatasetUrn.from_string(upstream)
        downstream = DatasetUrn.from_string(downstream)

        if column_lineage is None:
            cll = None
        elif column_lineage in ["auto_fuzzy", "auto_strict"]:
            upstream_schema = self._get_fields_from_dataset_urn(upstream)
            downstream_schema = self._get_fields_from_dataset_urn(downstream)
            if column_lineage == "auto_fuzzy":
                mapping = self._get_fuzzy_column_lineage(
                    upstream_schema, downstream_schema
                )
            else:
                mapping = self._get_strict_column_lineage(
                    upstream_schema, downstream_schema
                )
            cll = parse_cll_mapping(
                upstream=upstream,
                downstream=downstream,
                cll_mapping=mapping,
            )
        elif isinstance(column_lineage, dict):
            cll = parse_cll_mapping(
                upstream=upstream,
                downstream=downstream,
                cll_mapping=column_lineage,
            )
        else:
            # Fix: previously an unrecognized value left `cll` unbound and
            # surfaced as a confusing NameError below.
            raise SdkUsageError(
                f"Invalid column_lineage value: {column_lineage!r}. Expected "
                'None, "auto_fuzzy", "auto_strict", or a column mapping dict.'
            )

        updater = DatasetPatchBuilder(str(downstream))
        updater.add_upstream_lineage(
            models.UpstreamClass(
                dataset=str(upstream),
                type=models.DatasetLineageTypeClass.COPY,
            )
        )
        for cl in cll or []:
            updater.add_fine_grained_upstream_lineage(cl)

        self._client.entities.update(updater)

    def add_dataset_transform_lineage(
        self,
        *,
        upstream: DatasetUrnOrStr,
        downstream: DatasetUrnOrStr,
        column_lineage: Optional[ColumnLineageMapping] = None,
        query_text: Optional[str] = None,
    ) -> None:
        """Record that *downstream* is derived from *upstream* via a transform.

        Args:
            upstream: Source dataset urn.
            downstream: Derived dataset urn.
            column_lineage: Optional explicit downstream->upstreams mapping.
            query_text: Optional SQL text; when given, a Query entity is
                created (keyed by a fingerprint of the text) and linked from
                the lineage edge.

        Raises:
            SdkUsageError: If the downstream dataset does not exist.
        """
        upstream = DatasetUrn.from_string(upstream)
        downstream = DatasetUrn.from_string(downstream)

        cll = None
        if column_lineage is not None:
            cll = parse_cll_mapping(
                upstream=upstream,
                downstream=downstream,
                cll_mapping=column_lineage,
            )

        # All urns (datasets and schema fields) the query touches; used as the
        # query entity's subjects.
        fields_involved = OrderedSet([str(upstream), str(downstream)])
        if cll is not None:
            for c in cll:
                for field in c.upstreams or []:
                    fields_involved.add(field)
                for field in c.downstreams or []:
                    fields_involved.add(field)

        query_urn = None
        query_entity = None
        if query_text:
            # Eventually we might want to use our regex-based fingerprinting instead.
            fingerprint = generate_hash(query_text)
            query_urn = QueryUrn(fingerprint).urn()

            from datahub.sql_parsing.sql_parsing_aggregator import make_query_subjects

            query_entity = MetadataChangeProposalWrapper.construct_many(
                query_urn,
                aspects=[
                    models.QueryPropertiesClass(
                        statement=models.QueryStatementClass(
                            value=query_text, language=models.QueryLanguageClass.SQL
                        ),
                        source=models.QuerySourceClass.SYSTEM,
                        created=_empty_audit_stamp,
                        lastModified=_empty_audit_stamp,
                    ),
                    make_query_subjects(list(fields_involved)),
                ],
            )

        updater = DatasetPatchBuilder(str(downstream))
        updater.add_upstream_lineage(
            models.UpstreamClass(
                dataset=str(upstream),
                type=models.DatasetLineageTypeClass.TRANSFORMED,
                query=query_urn,
            )
        )
        for cl in cll or []:
            cl.query = query_urn
            updater.add_fine_grained_upstream_lineage(cl)

        # Throw if the dataset does not exist.
        # We need to manually call .build() instead of reusing client.update()
        # so that we make just one emit_mcps call.
        if not self._client._graph.exists(updater.urn):
            raise SdkUsageError(
                f"Dataset {updater.urn} does not exist, and hence cannot be updated."
            )
        mcps: List[
            Union[MetadataChangeProposalWrapper, models.MetadataChangeProposalClass]
        ] = list(updater.build())
        if query_entity:
            mcps.extend(query_entity)
        self._client._graph.emit_mcps(mcps)
0 commit comments