1
+ import numpy as np
2
+ import matplotlib .pyplot as plt
3
+ import pandas as pd
4
+
5
+ def plot_decision_boundaries (X , y , model_class , ** model_params ):
6
+ """Function to plot the decision boundaries of a classification model.
7
+ This uses just the first two columns of the data for fitting
8
+ the model as we need to find the predicted value for every point in
9
+ scatter plot.
10
+
11
+ Arguments:
12
+ X: Feature data as a NumPy-type array.
13
+ y: Label data as a NumPy-type array.
14
+ model_class: A Scikit-learn ML estimator class
15
+ e.g. GaussianNB (imported from sklearn.naive_bayes) or
16
+ LogisticRegression (imported from sklearn.linear_model)
17
+ **model_params: Model parameters to be passed on to the ML estimator
18
+
19
+ Typical code example:
20
+ plt.figure()
21
+ plt.title("KNN decision boundary with neighbros: 5",fontsize=16)
22
+ plot_decision_boundaries(X_train,y_train,KNeighborsClassifier,n_neighbors=5)
23
+ plt.show()
24
+ """
25
+ try :
26
+ X = np .array (X )
27
+ y = np .array (y ).flatten ()
28
+ except :
29
+ print ("Coercing input data to NumPy arrays failed" )
30
+ reduced_data = X [:, :2 ]
31
+ model = model_class (** model_params )
32
+ model .fit (reduced_data , y )
33
+
34
+ # Step size of the mesh. Decrease to increase the quality of the VQ.
35
+ h = .02 # point in the mesh [x_min, m_max]x[y_min, y_max].
36
+
37
+ # Plot the decision boundary. For that, we will assign a color to each
38
+ x_min , x_max = reduced_data [:, 0 ].min () - 1 , reduced_data [:, 0 ].max () + 1
39
+ y_min , y_max = reduced_data [:, 1 ].min () - 1 , reduced_data [:, 1 ].max () + 1
40
+ xx , yy = np .meshgrid (np .arange (x_min , x_max , h ), np .arange (y_min , y_max , h ))
41
+
42
+ # Obtain labels for each point in mesh using the model.
43
+ Z = model .predict (np .c_ [xx .ravel (), yy .ravel ()])
44
+
45
+ x_min , x_max = X [:, 0 ].min () - 1 , X [:, 0 ].max () + 1
46
+ y_min , y_max = X [:, 1 ].min () - 1 , X [:, 1 ].max () + 1
47
+ xx , yy = np .meshgrid (np .arange (x_min , x_max , 0.1 ),
48
+ np .arange (y_min , y_max , 0.1 ))
49
+
50
+ Z = model .predict (np .c_ [xx .ravel (), yy .ravel ()]).reshape (xx .shape )
51
+
52
+ plt .contourf (xx , yy , Z , alpha = 0.4 )
53
+ plt .scatter (X [:, 0 ], X [:, 1 ], c = y , alpha = 0.8 )
54
+ plt .xlabel ("Feature-1" ,fontsize = 15 )
55
+ plt .ylabel ("Feature-2" ,fontsize = 15 )
56
+ plt .xticks (fontsize = 14 )
57
+ plt .yticks (fontsize = 14 )
58
+ return plt
0 commit comments