3
3
import pandas as pd
4
4
5
5
def plot_decision_boundaries (X , y , model_class , ** model_params ):
6
- """Function to plot the decision boundaries of a classification model.
6
+ """
7
+ Function to plot the decision boundaries of a classification model.
7
8
This uses just the first two columns of the data for fitting
8
9
the model as we need to find the predicted value for every point in
9
10
scatter plot.
@@ -27,8 +28,11 @@ def plot_decision_boundaries(X, y, model_class, **model_params):
27
28
y = np .array (y ).flatten ()
28
29
except :
29
30
print ("Coercing input data to NumPy arrays failed" )
31
+ # Reduces to the first two columns of data
30
32
reduced_data = X [:, :2 ]
33
+ # Instantiate the model object
31
34
model = model_class (** model_params )
35
+ # Fits the model with the reduced data
32
36
model .fit (reduced_data , y )
33
37
34
38
# Step size of the mesh. Decrease to increase the quality of the VQ.
@@ -37,6 +41,7 @@ def plot_decision_boundaries(X, y, model_class, **model_params):
37
41
# Plot the decision boundary. For that, we will assign a color to each
38
42
x_min , x_max = reduced_data [:, 0 ].min () - 1 , reduced_data [:, 0 ].max () + 1
39
43
y_min , y_max = reduced_data [:, 1 ].min () - 1 , reduced_data [:, 1 ].max () + 1
44
+ # Meshgrid creation
40
45
xx , yy = np .meshgrid (np .arange (x_min , x_max , h ), np .arange (y_min , y_max , h ))
41
46
42
47
# Obtain labels for each point in mesh using the model.
@@ -47,8 +52,10 @@ def plot_decision_boundaries(X, y, model_class, **model_params):
47
52
xx , yy = np .meshgrid (np .arange (x_min , x_max , 0.1 ),
48
53
np .arange (y_min , y_max , 0.1 ))
49
54
55
+ # Predictions to obtain the classification results
50
56
Z = model .predict (np .c_ [xx .ravel (), yy .ravel ()]).reshape (xx .shape )
51
57
58
+ # Plotting
52
59
plt .contourf (xx , yy , Z , alpha = 0.4 )
53
60
plt .scatter (X [:, 0 ], X [:, 1 ], c = y , alpha = 0.8 )
54
61
plt .xlabel ("Feature-1" ,fontsize = 15 )
0 commit comments