Implementing kNN with scikit-learn

An earlier article, published before the Spring Festival, introduced the kNN algorithm and walked through how it works, which helps with understanding its behavior and with parameter tuning. Mature algorithms like this one naturally come with ready-made library implementations.


The scikit-learn package is a machine learning library for Python that implements a range of algorithms, including kNN, support vector machines, and k-means clustering.


Installing the scikit-learn package takes three commands (use pip or pip3 depending on your environment; note that the package name on PyPI is scikit-learn, not sklearn):

pip install numpy

pip install scipy

pip install scikit-learn
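
After installation, a quick import check confirms everything is in place (the version string you see depends on what pip installed):

>>> import sklearn
>>> sklearn.__version__   # e.g. '1.4.2'; your version may differ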


This article looks at how to implement kNN using the scikit-learn package. scikit-learn itself ships with many classic datasets, such as the iris dataset, which is often used for classification.


Data set preparation

>>> from sklearn import datasets

>>> import numpy as np

>>> iris=datasets.load_iris()

>>> iris_x=iris.data
>>> iris_y=iris.target


The code above loads the iris dataset and then separates its feature data from its labels. Let's take a look at the data first.

>>> np.shape(iris_x)
(150, 4)


>>> np.shape(iris_y)
(150,)


There are 150 samples in total, each with four features. If you are familiar with the iris dataset, you know these four features are sepal length, sepal width, petal length, and petal width.
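
You don't need to memorize the feature order; the dataset records it in its feature_names attribute:

>>> iris.feature_names
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']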

>>> np.unique(iris_y)
array([0, 1, 2])


The labels (that is, the iris species) come in three kinds: Iris setosa, Iris versicolor, and Iris virginica.
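
The mapping from the numeric labels 0-2 to the species names is stored in the dataset's target_names attribute:

>>> iris.target_names
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')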

>>> print(iris_y)
[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1
 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2
 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]


As you can see, the labels are sorted in order from 0 to 2. A dataset ordered like this is not suitable for direct use, so we need to shuffle it randomly and then split it into a training set and a test set.

>>> indices = np.random.permutation(len(iris_x))
>>> iris_x_train = iris_x[indices[:-10]]
>>> iris_y_train = iris_y[indices[:-10]]
>>> iris_x_test  = iris_x[indices[-10:]]
>>> iris_y_test  = iris_y[indices[-10:]]


We first generate a random permutation of the indices 0 through 149, then use it to take 140 samples as the training set and the remaining 10 as the test set.

>>> print(iris_y_train)
[2 0 1 1 1 0 2 2 0 0 2 2 2 1 1 1 0 2 1 1 2 2 2 1 2 0 1 1 0 0 2 2 1 0 1 0 2
 1 1 0 0 1 2 2 1 2 0 1 0 0 2 0 1 0 1 0 2 1 1 0 1 2 2 2 2 2 0 1 2 1 0 0 2 2
 1 0 1 2 1 0 0 2 2 1 1 2 2 1 1 2 2 0 2 1 0 2 0 0 0 2 1 0 1 0 0 0 0 0 2 1 1
 2 2 0 0 1 2 1 1 1 2 0 1 1 2 2 1 0 2 0 2 1 0 1 1 2 1 0 0 0]


Now the classes are shuffled and better suited for training. Because the permutation is random, your results may differ from the output above. With that, the data is ready.
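
As an aside, scikit-learn provides a helper that shuffles and splits in one call; a minimal sketch equivalent to the manual permutation above (test_size=10 keeps 10 samples for testing, and random_state=0 is an arbitrary seed that makes the shuffle reproducible):

>>> from sklearn.model_selection import train_test_split
>>> iris_x_train, iris_x_test, iris_y_train, iris_y_test = train_test_split(
...     iris_x, iris_y, test_size=10, random_state=0)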


Using the kNN algorithm

>>> from sklearn.neighbors import KNeighborsClassifier

>>> knn = KNeighborsClassifier(n_neighbors=3)


We import KNeighborsClassifier and construct a kNN classifier, passing in the parameter n_neighbors, which is our k value. The constructor accepts several other parameters; see the official documentation for details.
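
For illustration only (the rest of this article keeps the plain classifier above), here are a few of those other parameters; the values are examples, not recommendations:

>>> knn_weighted = KNeighborsClassifier(n_neighbors=3,
...                                     weights='distance',   # closer neighbors get larger vote weights
...                                     algorithm='kd_tree',  # use a k-d tree for the neighbor search
...                                     p=2)                  # p=2 means Euclidean distance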

>>> knn.fit(iris_x_train, iris_y_train)


This "trains" the kNN classifier. As you may recall from the algorithm introduction in the previous article, kNN does not actually train a model; the fit call exists to unify scikit-learn's estimator interface, and all it does here is store the sample data.
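
Since fit only stores the samples, you can query them directly. The kneighbors method returns the distances to, and training-set indices of, the k nearest neighbors of a query point (the actual numbers depend on your random split):

>>> distances, neighbor_indices = knn.kneighbors(iris_x_test[:1])
>>> distances.shape, neighbor_indices.shape
((1, 3), (1, 3))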

>>> iris_y_predict = knn.predict(iris_x_test)

>>> print(iris_y_predict)
[0 2 0 0 0 0 2 1 1 2]


We make predictions on the test set, get the predicted labels, and then check whether any of them are wrong:

>>> error_index = np.nonzero(iris_y_test - iris_y_predict)[0]
>>> print(error_index)
[2]


We find that the sample at index 2 was mispredicted (again, your index may differ). Now calculate the error rate:

>>> len(error_index) / len(iris_y_test)
0.1
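
Equivalently, the classifier's score method computes the test-set accuracy directly, which with the split above is 1 - 0.1:

>>> knn.score(iris_x_test, iris_y_test)
0.9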


As you can see, with scikit-learn we no longer need to implement the algorithm ourselves and can solve the problem much more conveniently. For a detailed introduction to scikit-learn's kNN, see the official documentation.

