Skip to content

Commit 203a05b

Browse files
authored
First upload with plot_decision_boundary function
1 parent 7a5cc2b commit 203a05b

File tree

1 file changed

+58
-0
lines changed

1 file changed

+58
-0
lines changed

Utilities/ML-Python-utils.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
import pandas as pd
4+
5+
def plot_decision_boundaries(X, y, model_class, **model_params):
6+
"""Function to plot the decision boundaries of a classification model.
7+
This uses just the first two columns of the data for fitting
8+
the model as we need to find the predicted value for every point in
9+
scatter plot.
10+
11+
Arguments:
12+
X: Feature data as a NumPy-type array.
13+
y: Label data as a NumPy-type array.
14+
model_class: A Scikit-learn ML estimator class
15+
e.g. GaussianNB (imported from sklearn.naive_bayes) or
16+
LogisticRegression (imported from sklearn.linear_model)
17+
**model_params: Model parameters to be passed on to the ML estimator
18+
19+
Typical code example:
20+
plt.figure()
21+
plt.title("KNN decision boundary with neighbros: 5",fontsize=16)
22+
plot_decision_boundaries(X_train,y_train,KNeighborsClassifier,n_neighbors=5)
23+
plt.show()
24+
"""
25+
try:
26+
X = np.array(X)
27+
y = np.array(y).flatten()
28+
except:
29+
print("Coercing input data to NumPy arrays failed")
30+
reduced_data = X[:, :2]
31+
model = model_class(**model_params)
32+
model.fit(reduced_data, y)
33+
34+
# Step size of the mesh. Decrease to increase the quality of the VQ.
35+
h = .02 # point in the mesh [x_min, m_max]x[y_min, y_max].
36+
37+
# Plot the decision boundary. For that, we will assign a color to each
38+
x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1
39+
y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1
40+
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
41+
42+
# Obtain labels for each point in mesh using the model.
43+
Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
44+
45+
x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
46+
y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
47+
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
48+
np.arange(y_min, y_max, 0.1))
49+
50+
Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
51+
52+
plt.contourf(xx, yy, Z, alpha=0.4)
53+
plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
54+
plt.xlabel("Feature-1",fontsize=15)
55+
plt.ylabel("Feature-2",fontsize=15)
56+
plt.xticks(fontsize=14)
57+
plt.yticks(fontsize=14)
58+
return plt

0 commit comments

Comments
 (0)