1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ import matplotlib .pyplot as plt
5
+ import joblib
6
+ from urllib .parse import urlparse
7
+ from sklearn .metrics import accuracy_score , precision_score , recall_score , confusion_matrix , roc_curve , classification_report
8
+ from sklearn .metrics import ConfusionMatrixDisplay
9
+ import mlflow
10
+ import mlflow .sklearn
11
+ from mlProject .entity .config_entity import ModelEvaluationConfig
12
+ from mlProject .utils .common import save_json
13
+ from pathlib import Path
14
+
15
+
16
class ModelEvaluation:
    """Evaluate a trained classifier on held-out test data and log the
    results to MLflow.

    Computes scalar metrics (accuracy / precision / recall), saves them
    to a local JSON file, logs metrics, parameters, confusion-matrix
    plots and the model itself to the configured MLflow tracking server,
    registering the model when the tracking store supports a registry.
    """

    def __init__(self, config: ModelEvaluationConfig):
        # config carries: test_data_path, model_path, target_column,
        # metric_file_name, all_params and mlflow_uri (see ModelEvaluationConfig).
        self.config = config

    def eval_metrics(self, actual, pred):
        """Return classification metrics for true vs. predicted labels.

        Args:
            actual: ground-truth labels.
            pred: predicted labels.

        Returns:
            Tuple ``(acc, prec, rec, cm, cm_nor, cr)`` — accuracy,
            precision, recall, confusion matrix, row-normalized confusion
            matrix, and the text classification report.

        NOTE(review): ``precision_score``/``recall_score`` are called with
        sklearn's default ``average='binary'`` — this assumes a binary
        target column; confirm before reusing for multiclass problems.
        """
        acc = accuracy_score(actual, pred)
        prec = precision_score(actual, pred)
        rec = recall_score(actual, pred)
        cm = confusion_matrix(actual, pred)
        cm_nor = confusion_matrix(actual, pred, normalize='true')
        cr = classification_report(actual, pred)
        return acc, prec, rec, cm, cm_nor, cr

    def _log_confusion_matrix_plot(self, matrix, filename):
        """Render ``matrix`` as a heat-map, log the PNG as an MLflow
        artifact, then close the figure and delete the temporary file.

        Fixes two issues in the original inline code: the plot/save/log
        sequence was duplicated for the raw and normalized matrices, and
        the temporary PNG files were left behind in the working directory.
        """
        disp = ConfusionMatrixDisplay(confusion_matrix=matrix)
        disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
        try:
            plt.savefig(filename)
            mlflow.log_artifact(filename)
        finally:
            # Always release the figure and remove the scratch file,
            # even if savefig/log_artifact raises.
            plt.close()
            if os.path.exists(filename):
                os.remove(filename)

    def log_into_mlflow(self):
        """Run the full evaluation pass and record everything in MLflow.

        Side effects:
            - writes the scalar metrics to ``self.config.metric_file_name``
              (JSON, via ``save_json``);
            - logs params, metrics, the confusion matrix (JSON + plots)
              and the model to the active MLflow run.
        """
        test_data = pd.read_csv(self.config.test_data_path)
        model = joblib.load(self.config.model_path)

        X_test = test_data.drop([self.config.target_column], axis=1)
        y_test = test_data[[self.config.target_column]]

        mlflow.set_registry_uri(self.config.mlflow_uri)
        # 'file' scheme means a local store with no model-registry support.
        tracking_url_type_store = urlparse(mlflow.get_tracking_uri()).scheme

        with mlflow.start_run():
            predicted_qualities = model.predict(X_test)

            (acc, prec, rec, cm, cm_nor, cr) = self.eval_metrics(
                y_test, predicted_qualities
            )

            # Persist scalar metrics locally; the confusion matrix is
            # converted to a plain nested list so it is JSON-serializable.
            scores = {
                "Accuracy": acc,
                "Precision": prec,
                "Recall": rec,
                "Confusion Mat": np.array(cm).tolist(),
            }
            save_json(path=Path(self.config.metric_file_name), data=scores)

            mlflow.log_params(self.config.all_params)

            mlflow.log_metric("Accuracy", acc)
            mlflow.log_metric("Precision", prec)
            mlflow.log_metric("Recall", rec)
            # NOTE: the text classification report (cr) is computed but not
            # logged — mlflow.log_metric only accepts numeric values.

            mlflow.log_dict(np.array(cm).tolist(), "confusion_matrix.json")

            self._log_confusion_matrix_plot(cm, "ConfusionMatrix.png")
            self._log_confusion_matrix_plot(cm_nor, "NormalizedConfusionMatrix.png")

            # Model registry does not work with a file-based store.
            # See https://mlflow.org/docs/latest/model-registry.html#api-workflow
            if tracking_url_type_store != "file":
                mlflow.sklearn.log_model(
                    model, "model", registered_model_name="RandomForestClassifier"
                )
            else:
                mlflow.sklearn.log_model(model, "model")
89
+
90
+
0 commit comments