import random
import time
from dataclasses import dataclass
from typing import Iterable, List, Optional, Union

import datahub.metadata.schema_classes as models
from datahub.api.entities.datajob import DataFlow, DataJob
from datahub.api.entities.dataprocess.dataprocess_instance import (
    DataProcessInstance,
    InstanceRunResult,
)
from datahub.api.entities.dataset.dataset import Dataset
from datahub.emitter.mcp import MetadataChangeProposalWrapper
from datahub.emitter.mcp_builder import ContainerKey
from datahub.ingestion.graph.client import DataHubGraph, DatahubClientConfig
from datahub.metadata.urns import (
    DataPlatformUrn,
    DatasetUrn,
    MlModelGroupUrn,
    MlModelUrn,
    VersionSetUrn,
)
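
# Publishes a synthetic ML demo pipeline to DataHub: an MLflow experiment
# container holding five training runs with versioned ARIMA models, plus an
# Airflow flow/job that reads the same input dataset.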

ORCHESTRATOR_MLFLOW = "mlflow"
ORCHESTRATOR_AIRFLOW = "airflow"


class ContainerKeyWithId(ContainerKey):
    id: str


@dataclass
class Container:
    key: ContainerKeyWithId
    subtype: str
    name: Optional[str] = None
    description: Optional[str] = None

    def generate_mcp(
        self,
    ) -> Iterable[
        Union[models.MetadataChangeProposalClass, MetadataChangeProposalWrapper]
    ]:
        container_urn = self.key.as_urn()
        container_subtype = models.SubTypesClass(typeNames=[self.subtype])
        container_info = models.ContainerPropertiesClass(
            name=self.name or self.key.id,
            description=self.description,
            customProperties={},
        )
        browse_path = models.BrowsePathsV2Class(path=[])
        dpi = models.DataPlatformInstanceClass(
            platform=self.key.platform,
            instance=self.key.instance,
        )

        yield from MetadataChangeProposalWrapper.construct_many(
            entityUrn=container_urn,
            aspects=[container_subtype, container_info, browse_path, dpi],
        )
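

# create_model emits three related aspects: versionSetProperties on the model's
# version set, then versionProperties and mlModelProperties on the model itself,
# linking the model back to its group and its training run.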
def create_model(
    model_name: str,
    model_group_urn: str,
    data_process_instance_urn: str,
    tags: List[str],
    version_aliases: List[str],
    index: int,
    training_metrics: List[models.MLMetricClass],
    hyper_params: List[models.MLHyperParamClass],
    model_description: str,
    created_at: int,
) -> Iterable[MetadataChangeProposalWrapper]:
    model_urn = MlModelUrn(platform="mlflow", name=model_name)

    # Create model properties
    model_info = models.MLModelPropertiesClass(
        description=model_description,
        version=models.VersionTagClass(versionTag=f"{index}"),
        groups=[str(model_group_urn)],
        trainingJobs=[str(data_process_instance_urn)],
        date=created_at,
        tags=tags,
        trainingMetrics=training_metrics,
        hyperParams=hyper_params,
        created=models.TimeStampClass(
            time=created_at, actor="urn:li:corpuser:datahub"
        ),
        lastModified=models.TimeStampClass(
            time=created_at, actor="urn:li:corpuser:datahub"
        ),
    )

    # Create version set
    version_set_urn = VersionSetUrn(
        id=f"mlmodel_{model_name}_versions", entity_type="mlModel"
    )
    version_entity = models.VersionSetPropertiesClass(
        latest=str(model_urn),
        versioningScheme="ALPHANUMERIC_GENERATED_BY_DATAHUB",
    )

    # Create version properties
    model_version_info = models.VersionPropertiesClass(
        version=models.VersionTagClass(versionTag=f"{index}"),
        versionSet=str(version_set_urn),
        aliases=[
            models.VersionTagClass(versionTag=alias) for alias in version_aliases
        ],
        sortId="AAAAAAAA",
    )

    # Yield all MCPs
    yield MetadataChangeProposalWrapper(
        entityUrn=str(version_set_urn),
        entityType="versionSet",
        aspectName="versionSetProperties",
        aspect=version_entity,
        changeType=models.ChangeTypeClass.UPSERT,
    )

    yield MetadataChangeProposalWrapper(
        entityUrn=str(model_urn),
        entityType="mlModel",
        aspectName="versionProperties",
        aspect=model_version_info,
        changeType=models.ChangeTypeClass.UPSERT,
    )

    yield MetadataChangeProposalWrapper(
        entityUrn=str(model_urn),
        entityType="mlModel",
        aspectName="mlModelProperties",
        aspect=model_info,
        changeType=models.ChangeTypeClass.UPSERT,
    )
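

# generate_pipeline yields all MCPs for one pipeline. For MLflow it emits an
# experiment container, a model group, and five training runs with their output
# datasets and versioned models; for Airflow it emits the flow and its training
# job. The input dataset is emitted in both cases.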
def generate_pipeline(
    pipeline_name: str,
    orchestrator: str,
) -> Iterable[
    Union[models.MetadataChangeProposalClass, MetadataChangeProposalWrapper]
]:
    data_flow = DataFlow(
        id=pipeline_name,
        orchestrator=orchestrator,
        cluster="default",
        name=pipeline_name,
    )

    data_job = DataJob(id="training", flow_urn=data_flow.urn, name="Training")

    input_dataset = Dataset(
        id="airline_passengers",
        name="Airline Passengers",
        description="Monthly airline passenger data",
        properties={},
        platform="s3",
        schema=None,
    )

    if orchestrator == ORCHESTRATOR_MLFLOW:
        experiment = Container(
            key=ContainerKeyWithId(
                platform=str(DataPlatformUrn.create_from_id("mlflow")),
                id="airline_forecast_experiment",
            ),
            subtype="ML Experiment",
            name="Airline Forecast Experiment",
            description="Experiment for forecasting airline passengers",
        )

        yield from experiment.generate_mcp()

        model_group_urn = MlModelGroupUrn(
            platform="mlflow", name="airline_forecast_models"
        )
        current_time = int(time.time() * 1000)
        model_group_info = models.MLModelGroupPropertiesClass(
            description="ML models for airline passenger forecasting",
            customProperties={
                "stage": "production",
                "team": "data_science",
            },
            created=models.TimeStampClass(
                time=current_time, actor="urn:li:corpuser:datahub"
            ),
            lastModified=models.TimeStampClass(
                time=current_time, actor="urn:li:corpuser:datahub"
            ),
        )

        yield MetadataChangeProposalWrapper(
            entityUrn=str(model_group_urn),
            entityType="mlModelGroup",
            aspectName="mlModelGroupProperties",
            aspect=model_group_info,
            changeType=models.ChangeTypeClass.UPSERT,
        )

        model_aliases = [
            "challenger",
            "champion",
            "production",
            "experimental",
            "deprecated",
        ]
        model_tags = [
            "stage:production",
            "stage:development",
            "team:data_science",
            "team:ml_engineering",
            "team:analytics",
        ]

        model_dict = {
            "arima_model_1": "ARIMA model for airline passenger forecasting",
            "arima_model_2": "Enhanced ARIMA model with seasonal components",
            "arima_model_3": "ARIMA model optimized for long-term forecasting",
            "arima_model_4": "ARIMA model with hyperparameter tuning",
            "arima_model_5": "ARIMA model trained on extended dataset",
        }

        # Generate run timestamps within the last month
        end_time = int(time.time() * 1000)
        start_time = end_time - (30 * 24 * 60 * 60 * 1000)
        run_timestamps = [
            start_time + (i * 5 * 24 * 60 * 60 * 1000) for i in range(5)
        ]

        run_dict = {
            "run_1": {"start_time": run_timestamps[0], "duration": 45, "result": InstanceRunResult.SUCCESS},
            "run_2": {"start_time": run_timestamps[1], "duration": 60, "result": InstanceRunResult.FAILURE},
            "run_3": {"start_time": run_timestamps[2], "duration": 55, "result": InstanceRunResult.SUCCESS},
            "run_4": {"start_time": run_timestamps[3], "duration": 70, "result": InstanceRunResult.SUCCESS},
            "run_5": {"start_time": run_timestamps[4], "duration": 50, "result": InstanceRunResult.FAILURE},
        }
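
        # One synthetic training run per model: each run is a DataProcessInstance
        # inside the experiment container, with an S3 inlet and outlet, random
        # metrics and hyperparameters, and a versioned model linked to the run.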
        for i, (model_name, model_description) in enumerate(
            model_dict.items(), start=1
        ):
            run_id = f"run_{i}"
            data_process_instance = DataProcessInstance.from_container(
                container_key=experiment.key, id=run_id
            )

            data_process_instance.subtype = "Training Run"
            data_process_instance.inlets = [DatasetUrn.from_string(input_dataset.urn)]

            output_dataset = Dataset(
                id=f"passenger_forecast_24_12_0{i}",
                name=f"Passenger Forecast 24_12_0{i}",
                description=f"Forecasted airline passenger numbers for run {i}",
                properties={},
                platform="s3",
                schema=None,
            )
            yield from output_dataset.generate_mcp()

            data_process_instance.outlets = [DatasetUrn.from_string(output_dataset.urn)]

            # Training metrics and hyperparameters
            training_metrics = [
                models.MLMetricClass(
                    name="accuracy",
                    value=str(random.uniform(0.7, 0.99)),
                    description="Test accuracy",
                ),
                models.MLMetricClass(
                    name="f1_score",
                    value=str(random.uniform(0.7, 0.99)),
                    description="Test F1 score",
                ),
            ]
            hyper_params = [
                models.MLHyperParamClass(
                    name="n_estimators",
                    value=str(random.randint(50, 200)),
                    description="Number of trees",
                ),
                models.MLHyperParamClass(
                    name="max_depth",
                    value=str(random.randint(5, 15)),
                    description="Maximum tree depth",
                ),
            ]

            # DPI properties
            created_at = int(time.time() * 1000)
            dpi_props = models.DataProcessInstancePropertiesClass(
                name=f"Training {run_id}",
                created=models.AuditStampClass(
                    time=created_at, actor="urn:li:corpuser:datahub"
                ),
                customProperties={
                    "framework": "statsmodels",
                    "python_version": "3.8",
                },
            )

            mlrun_props = models.MLTrainingRunPropertiesClass(
                id=run_id,
                outputUrls=["s3://mlflow/artifacts"],
                hyperParams=hyper_params,
                trainingMetrics=training_metrics,
                externalUrl="http://mlflow:5000",
            )
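
            # Emit the run: materialize its iolets first, then attach the
            # instance properties and training-run details, and finish with
            # the start/end run events.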
            yield from data_process_instance.generate_mcp(
                created_ts_millis=created_at, materialize_iolets=True
            )

            yield MetadataChangeProposalWrapper(
                entityUrn=str(data_process_instance.urn),
                aspect=dpi_props,
            )

            yield MetadataChangeProposalWrapper(
                entityUrn=str(data_process_instance.urn),
                aspect=mlrun_props,
            )

            # Generate start and end events
            start_time_millis = run_dict[run_id]["start_time"]
            duration_minutes = run_dict[run_id]["duration"]
            end_time_millis = start_time_millis + duration_minutes * 60000
            result = run_dict[run_id]["result"]
            result_type = (
                "SUCCESS" if result == InstanceRunResult.SUCCESS else "FAILURE"
            )

            yield from data_process_instance.start_event_mcp(
                start_timestamp_millis=start_time_millis
            )
            yield from data_process_instance.end_event_mcp(
                end_timestamp_millis=end_time_millis,
                result=result,
                result_type=result_type,
                start_timestamp_millis=start_time_millis,
            )

            # Model
            selected_aliases = random.sample(model_aliases, k=random.randint(1, 2))
            selected_tags = random.sample(model_tags, 2)
            yield from create_model(
                model_name=model_name,
                model_group_urn=str(model_group_urn),
                data_process_instance_urn=str(data_process_instance.urn),
                tags=selected_tags,
                version_aliases=selected_aliases,
                index=i,
                training_metrics=training_metrics,
                hyper_params=hyper_params,
                model_description=model_description,
                created_at=created_at,
            )

    if orchestrator == ORCHESTRATOR_AIRFLOW:
        yield from data_flow.generate_mcp()
        yield from data_job.generate_mcp()

    # Both pipeline flavors share the same input dataset
    yield from input_dataset.generate_mcp()
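

# This assumes a local DataHub instance at localhost:8080 (e.g. from
# `datahub docker quickstart`) with a personal access token for auth.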
if __name__ == "__main__":
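    # NOTE: a hardcoded demo token; replace it with your own (or read it from
    # an environment variable) before running.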
    token = "eyJhbGciOiJIUzI1NiJ9.eyJhY3RvclR5cGUiOiJVU0VSIiwiYWN0b3JJZCI6ImRhdGFodWIiLCJ0eXBlIjoiUEVSU09OQUwiLCJ2ZXJzaW9uIjoiMiIsImp0aSI6Ijg3MWEyZjU2LTY2MjUtNGRiMC04OTZhLTAyMzBmNmM0MmRkZCIsInN1YiI6ImRhdGFodWIiLCJleHAiOjE3Mzk2ODcwMDIsImlzcyI6ImRhdGFodWItbWV0YWRhdGEtc2VydmljZSJ9.HDGaXw8iBTXIEqKyIQl-jSlS8BquAXZHELP4hA9thOM"
    graph_config = DatahubClientConfig(
        server="http://localhost:8080",
        token=token,
        extra_headers={"Authorization": f"Bearer {token}"},
    )
    graph = DataHubGraph(graph_config)
    with graph:
        for mcp in generate_pipeline(
            "airline_forecast_pipeline_mlflow", orchestrator=ORCHESTRATOR_MLFLOW
        ):
            graph.emit(mcp)
        for mcp in generate_pipeline(
            "airline_forecast_pipeline_airflow", orchestrator=ORCHESTRATOR_AIRFLOW
        ):
            graph.emit(mcp)
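
# A sketch of an alternative emission path (assumes the same local endpoint):
# the REST emitter sends MCPs directly without constructing a graph client.
#
#   from datahub.emitter.rest_emitter import DatahubRestEmitter
#
#   emitter = DatahubRestEmitter(gms_server="http://localhost:8080", token=token)
#   for mcp in generate_pipeline(
#       "airline_forecast_pipeline_mlflow", orchestrator=ORCHESTRATOR_MLFLOW
#   ):
#       emitter.emit(mcp)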