Skip to content

Commit 16e7d50

Browse files
authored
Added multiple comments
1 parent 203a05b commit 16e7d50

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

Utilities/ML-Python-utils.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import pandas as pd
44

55
def plot_decision_boundaries(X, y, model_class, **model_params):
6-
"""Function to plot the decision boundaries of a classification model.
6+
"""
7+
Function to plot the decision boundaries of a classification model.
78
This uses just the first two columns of the data for fitting
89
the model as we need to find the predicted value for every point in
910
scatter plot.
@@ -27,8 +28,11 @@ def plot_decision_boundaries(X, y, model_class, **model_params):
2728
y = np.array(y).flatten()
2829
except:
2930
print("Coercing input data to NumPy arrays failed")
31+
# Reduces to the first two columns of data
3032
reduced_data = X[:, :2]
33+
# Instantiate the model object
3134
model = model_class(**model_params)
35+
# Fits the model with the reduced data
3236
model.fit(reduced_data, y)
3337

3438
# Step size of the mesh. Decrease to increase the quality of the VQ.
@@ -37,6 +41,7 @@ def plot_decision_boundaries(X, y, model_class, **model_params):
3741
# Plot the decision boundary. For that, we will assign a color to each
3842
x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1
3943
y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1
44+
# Meshgrid creation
4045
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
4146

4247
# Obtain labels for each point in mesh using the model.
@@ -47,8 +52,10 @@ def plot_decision_boundaries(X, y, model_class, **model_params):
4752
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.1),
4853
np.arange(y_min, y_max, 0.1))
4954

55+
# Predictions to obtain the classification results
5056
Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
5157

58+
# Plotting
5259
plt.contourf(xx, yy, Z, alpha=0.4)
5360
plt.scatter(X[:, 0], X[:, 1], c=y, alpha=0.8)
5461
plt.xlabel("Feature-1",fontsize=15)

0 commit comments

Comments
 (0)