Skip to content

Commit

Permalink
fix lineage bug & add logging & fix linting
Browse files Browse the repository at this point in the history
  • Loading branch information
yoonhyejin committed Jan 28, 2025
1 parent bc19996 commit e0b0faf
Show file tree
Hide file tree
Showing 11 changed files with 70 additions and 41 deletions.
3 changes: 2 additions & 1 deletion metadata-ingestion/examples/ml/add_input_dataset_to_run.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
from typing import Optional

from datahub.api.entities.dataset.dataset import Dataset
from datahub.emitter.mcp import MetadataChangeProposalWrapper
Expand All @@ -7,7 +8,7 @@
DataProcessInstanceInput,
)
from datahub.metadata.schema_classes import ChangeTypeClass
from typing import Optional


def create_dataset(
platform: str,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
import argparse
from typing import Optional

import datahub.metadata.schema_classes as models
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from typing import Optional


def add_model_version_to_model(
model_urn: str, model_group_urn: str, token: Optional[str], server_url: str = "http://localhost:8080"
model_urn: str,
model_group_urn: str,
token: Optional[str],
server_url: str = "http://localhost:8080",
) -> None:
# Create model properties
model_properties = models.MLModelPropertiesClass(groups=[model_group_urn])
Expand Down
22 changes: 11 additions & 11 deletions metadata-ingestion/examples/ml/add_output_dataset_to_run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import argparse
from typing import Optional, List
from typing import List

from datahub.api.entities.dataset.dataset import Dataset
from datahub.emitter.mcp import MetadataChangeProposalWrapper
Expand All @@ -11,11 +11,11 @@


def create_dataset(
platform: str,
name: str,
token: str,
description: str = "",
server_url: str = "http://localhost:8080",
platform: str,
name: str,
token: str,
description: str = "",
server_url: str = "http://localhost:8080",
) -> str:
"""Create a dataset in DataHub and return its URN.
Expand Down Expand Up @@ -50,10 +50,10 @@ def create_dataset(


def add_output_datasets_to_run(
run_urn: str,
dataset_urns: List[str],
token: str,
server_url: str = "http://localhost:8080",
run_urn: str,
dataset_urns: List[str],
token: str,
server_url: str = "http://localhost:8080",
) -> None:
"""Add output datasets to a data process instance run.
Expand Down Expand Up @@ -122,4 +122,4 @@ def add_output_datasets_to_run(
run_urn="urn:li:dataProcessInstance:c29762bd7cc66e35414d95350454e542",
dataset_urns=dataset_urns,
token=args.token,
)
)
3 changes: 2 additions & 1 deletion metadata-ingestion/examples/ml/add_run_to_experiment.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import argparse
from typing import Optional

import datahub.metadata.schema_classes as models
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from typing import Optional


def add_run_to_experiment(
run_urn: str,
Expand Down
7 changes: 5 additions & 2 deletions metadata-ingestion/examples/ml/add_run_to_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import argparse
from typing import Optional

import datahub.metadata.schema_classes as models
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from typing import Optional


def add_run_to_model(
model_urn: str, run_urn: str, token: Optional[str], server_url: str = "http://localhost:8080"
model_urn: str,
run_urn: str,
token: Optional[str],
server_url: str = "http://localhost:8080",
) -> None:
# Create model properties
model_properties = models.MLModelPropertiesClass(
Expand Down
3 changes: 2 additions & 1 deletion metadata-ingestion/examples/ml/add_run_to_model_group.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import argparse
from typing import Optional

import datahub.metadata.schema_classes as models
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from typing import Optional


def add_run_to_model_group(
model_group_urn: str,
Expand Down
3 changes: 2 additions & 1 deletion metadata-ingestion/examples/ml/create_ml_experiment.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import argparse
from typing import Optional

import datahub.metadata.schema_classes as models
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.ingestion.graph.client import DatahubClientConfig, DataHubGraph
from datahub.metadata.urns import ContainerUrn, DataPlatformUrn
from typing import Optional


def create_experiment(
experiment_id: str,
Expand Down
8 changes: 6 additions & 2 deletions metadata-ingestion/examples/ml/create_ml_training_run.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import time
from typing import Optional

import datahub.metadata.schema_classes as models
from datahub.emitter.mcp import MetadataChangeProposalWrapper
Expand All @@ -8,10 +9,13 @@
AuditStampClass,
DataProcessInstancePropertiesClass,
)
from typing import Optional


def create_minimal_training_run(
run_id: str, name: str, token: Optional[str], server_url: str = "http://localhost:8080"
run_id: str,
name: str,
token: Optional[str],
server_url: str = "http://localhost:8080",
) -> None:
# Create a container key (required for DataProcessInstance)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import time
from typing import Optional

import datahub.metadata.schema_classes as models
from datahub.api.entities.dataprocess.dataprocess_instance import DataProcessInstance
Expand All @@ -9,7 +10,7 @@
AuditStampClass,
DataProcessInstancePropertiesClass,
)
from typing import Optional


def create_minimal_training_run(
run_id: str,
Expand Down
43 changes: 27 additions & 16 deletions metadata-ingestion/examples/ml/mlflow_dh_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import time
from enum import Enum
from typing import Any, Dict, List, Optional, Union

import datahub.metadata.schema_classes as models
Expand All @@ -13,9 +12,8 @@
)
from datahub.metadata.schema_classes import (
ChangeTypeClass,
DataProcessRunStatusClass,
DataProcessInstanceRunResultClass,
AuditStampClass,
DataProcessRunStatusClass,
)
from datahub.metadata.urns import (
ContainerUrn,
Expand All @@ -25,6 +23,10 @@
VersionSetUrn,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MLflowDatahubClient:
"""Client for creating and managing MLflow metadata in DataHub."""

Expand Down Expand Up @@ -55,9 +57,8 @@ def _create_timestamp(
)

def _emit_mcps(
self,
mcps: Union[
List[MetadataChangeProposalWrapper], MetadataChangeProposalWrapper],
self,
mcps: Union[List[MetadataChangeProposalWrapper], MetadataChangeProposalWrapper],
) -> None:
"""Helper to emit MCPs with proper connection handling"""
if not isinstance(mcps, list):
Expand All @@ -73,7 +74,7 @@ def _get_aspect(
try:
return self.graph.get_aspect(entity_urn=entity_urn, aspect_type=aspect_type)
except Exception as e:
logging.warning(f"Could not fetch aspect for {entity_urn}: {e}")
logger.warning(f"Could not fetch aspect for {entity_urn}: {e}")
return default_constructor() if default_constructor else None

def _create_properties_class(
Expand Down Expand Up @@ -194,8 +195,7 @@ def _create_run_event(

if result:
event_args["result"] = DataProcessInstanceRunResultClass(
type=result,
nativeResultType=str(result)
type=result, nativeResultType=str(result)
)
if duration_millis:
event_args["durationMillis"] = duration_millis
Expand All @@ -220,6 +220,7 @@ def create_model_group(
str(model_group_urn), properties, "mlModelGroup", "mlModelGroupProperties"
)
self._emit_mcps(mcp)
logger.info(f"Created model group: {model_group_urn}")
return str(model_group_urn)

def create_model(
Expand Down Expand Up @@ -269,7 +270,9 @@ def create_model(
)

mcps = [
self._create_mcp(str(model_urn), properties, "mlModel", "mlModelProperties"),
self._create_mcp(
str(model_urn), properties, "mlModel", "mlModelProperties"
),
self._create_mcp(
str(version_set_urn),
version_set_properties,
Expand All @@ -281,6 +284,7 @@ def create_model(
),
]
self._emit_mcps(mcps)
logger.info(f"Created model: {model_urn}")
return str(model_urn)

def create_experiment(
Expand All @@ -307,6 +311,7 @@ def create_experiment(
aspects=[container_subtype, properties, browse_path, platform_instance],
)
self._emit_mcps(mcps)
logger.info(f"Created experiment: {container_urn}")
return str(container_urn)

def create_training_run(
Expand Down Expand Up @@ -345,8 +350,7 @@ def create_training_run(
# Create events
aspects.append(
self._create_run_event(
status=DataProcessRunStatusClass.STARTED,
timestamp=start_ts
status=DataProcessRunStatusClass.STARTED, timestamp=start_ts
)
)

Expand All @@ -363,6 +367,7 @@ def create_training_run(
# Create and emit MCPs
mcps = [self._create_mcp(dpi_urn, aspect) for aspect in aspects]
self._emit_mcps(mcps)
logger.info(f"Created training run: {dpi_urn}")
return dpi_urn

def create_dataset(self, name: str, platform: str, **kwargs: Any) -> str:
Expand All @@ -379,20 +384,22 @@ def add_run_to_model(self, model_urn: str, run_urn: str) -> None:
self._update_entity_properties(
entity_urn=model_urn,
aspect_type=models.MLModelPropertiesClass,
updates={"trainingJobs": run_urn, "downstreamJobs": run_urn},
updates={"trainingJobs": run_urn},
entity_type="mlModel",
skip_properties=["trainingJobs", "downstreamJobs"],
skip_properties=["trainingJobs"],
)
logger.info(f"Added run {run_urn} to model {model_urn}")

def add_run_to_model_group(self, model_group_urn: str, run_urn: str) -> None:
"""Add a run to a model group while preserving existing properties."""
self._update_entity_properties(
entity_urn=model_group_urn,
aspect_type=models.MLModelGroupPropertiesClass,
updates={"trainingJobs": run_urn, "downstreamJobs": run_urn},
updates={"trainingJobs": run_urn},
entity_type="mlModelGroup",
skip_properties=["trainingJobs", "downstreamJobs"],
skip_properties=["trainingJobs"],
)
logger.info(f"Added run {run_urn} to model group {model_group_urn}")

def add_model_to_model_group(self, model_urn: str, group_urn: str) -> None:
"""Add a model to a group while preserving existing properties"""
Expand All @@ -403,13 +410,15 @@ def add_model_to_model_group(self, model_urn: str, group_urn: str) -> None:
entity_type="mlModel",
skip_properties=["groups"],
)
logger.info(f"Added model {model_urn} to group {group_urn}")

def add_run_to_experiment(self, run_urn: str, experiment_urn: str) -> None:
"""Add a run to an experiment"""
mcp = self._create_mcp(
entity_urn=run_urn, aspect=models.ContainerClass(container=experiment_urn)
)
self._emit_mcps(mcp)
logger.info(f"Added run {run_urn} to experiment {experiment_urn}")

def add_input_datasets_to_run(self, run_urn: str, dataset_urns: List[str]) -> None:
"""Add input datasets to a run"""
Expand All @@ -420,6 +429,7 @@ def add_input_datasets_to_run(self, run_urn: str, dataset_urns: List[str]) -> No
aspect=DataProcessInstanceInput(inputs=dataset_urns),
)
self._emit_mcps(mcp)
logger.info(f"Added input datasets to run {run_urn}")

def add_output_datasets_to_run(self, run_urn: str, dataset_urns: List[str]) -> None:
"""Add output datasets to a run"""
Expand All @@ -430,3 +440,4 @@ def add_output_datasets_to_run(self, run_urn: str, dataset_urns: List[str]) -> N
aspect=DataProcessInstanceOutput(outputs=dataset_urns),
)
self._emit_mcps(mcp)
logger.info(f"Added output datasets to run {run_urn}")
8 changes: 5 additions & 3 deletions metadata-ingestion/examples/ml/mlflow_dh_client_sample.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import argparse

from mlflow_dh_client import MLflowDatahubClient

import datahub.metadata.schema_classes as models
from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import RunResultType

Expand Down Expand Up @@ -68,16 +70,16 @@
)

run_urn = client.create_training_run(
run_id="simple_training_run_4",
run_id="simple_training_run",
properties=models.DataProcessInstancePropertiesClass(
name="Simple Training Run 4",
name="Simple Training Run",
created=models.AuditStampClass(
time=1628580000000, actor="urn:li:corpuser:datahub"
),
customProperties={"team": "forecasting"},
),
training_run_properties=models.MLTrainingRunPropertiesClass(
id="simple_training_run_4",
id="simple_training_run",
outputUrls=["s3://my-bucket/output"],
trainingMetrics=[models.MLMetricClass(name="accuracy", value="0.9")],
hyperParams=[models.MLHyperParamClass(name="learning_rate", value="0.01")],
Expand Down

0 comments on commit e0b0faf

Please sign in to comment.