def plot_decision_boundary(model, axis, resolution=100):
    """Shade a 2-D feature plane by the class that ``model`` predicts.

    Builds a dense grid over ``axis``, calls ``model.predict`` on every grid
    point, and draws the predictions as filled contours with ``plt.contourf``
    on the current matplotlib axes.

    Parameters
    ----------
    model : object
        Fitted estimator exposing ``predict``. It is called with an
        ``(n_points, 2)`` array whose columns are the x0 and x1 grid
        coordinates; its output must reshape to the grid shape.
    axis : sequence of 4 numbers
        Plot limits ``[x0_min, x0_max, x1_min, x1_max]``.
    resolution : int, default 100
        Grid points per unit length along each axis (truncated with ``int``,
        as before). At least 2 points are always used per axis.

    Returns
    -------
    None
    """
    from matplotlib.colors import ListedColormap

    # Guarantee at least a 2x2 grid: contourf cannot draw from a single
    # row/column, which int(...) would yield for very narrow ranges.
    n0 = max(int((axis[1] - axis[0]) * resolution), 2)
    n1 = max(int((axis[3] - axis[2]) * resolution), 2)

    # meshgrid takes the 1-D coordinate vectors directly; no reshape needed.
    x0, x1 = np.meshgrid(
        np.linspace(axis[0], axis[1], n0),
        np.linspace(axis[2], axis[3], n1),
    )
    # One row per grid point: (x0, x1) pairs in the order predict expects.
    X_new = np.c_[x0.ravel(), x1.ravel()]
    zz = model.predict(X_new).reshape(x0.shape)

    # Three light fill colours (red, yellow, blue); presumably one per class
    # of a 3-class problem such as iris -- confirm for other datasets.
    custom_cmap = ListedColormap(['#EF9A9A', '#FFF59D', '#90CAF9'])
    # contourf draws filled regions and has no line-width option; the former
    # ``linewidth=5`` argument was unused and only triggered a warning.
    plt.contourf(x0, x1, zz, cmap=custom_cmap)
# Fit a k-nearest-neighbours classifier on the first two iris feature columns
# only, so its decision regions can be drawn in the plane.
# Library defaults are used; the notebook echoed the fitted estimator as
# KNeighborsClassifier(..., metric_params=None, n_jobs=1, n_neighbors=5,
# p=2, weights='uniform').
knn_clf_all = KNeighborsClassifier()
knn_clf_all.fit(iris.data[:, :2], iris.target)

# Shade the predicted regions, then overlay the training points with one
# scatter call per class label, in ascending label order.
plot_decision_boundary(knn_clf_all, axis=[4, 8, 1.5, 4.5])
for class_label in sorted(set(iris.target)):
    mask = iris.target == class_label
    plt.scatter(iris.data[mask, 0], iris.data[mask, 1])
plt.show()