diff --git a/train.py b/train.py index c08584f..e3346c7 100644 --- a/train.py +++ b/train.py @@ -2,7 +2,7 @@ import os from sklearn.ensemble import RandomForestClassifier -from sklearn.metrics import plot_confusion_matrix +from sklearn.metrics import ConfusionMatrixDisplay import matplotlib.pyplot as plt import numpy as np @@ -29,5 +29,5 @@ outfile.write(metrics) # Plot it -disp = plot_confusion_matrix(clf, X_test, y_test, normalize="true", cmap=plt.cm.Blues) +disp = ConfusionMatrixDisplay.from_estimator(clf, X_test, y_test, normalize="true", cmap=plt.cm.Blues) plt.savefig("plot.png")