KNN implementa la clasificación de dígitos de MNIST
import torch
from torch.utils.data import DataLoader
import torchvision.datasets as dsets
import numpy as np
import operator
batch_size = 100
data_path = './'
def KNN_distance(k, dis, trains_1, labels_2, test):
#曼哈顿距离(Manhattan distance): 简称M
#欧式距离(Euclidean Metric): 简称E
assert dis == 'M' or dis == 'E'
count = test.shape[0]
label_list = []
# 欧式距离 sqrt((x1-x2)^2 + (y1-y2)^2)
if dis == 'E':
for i in range(count):
distance = np.sqrt(np.sum(((trains_1 - np.tile(test[i], (trains_1.shape[0], 1))) ** 2), axis=1))
nearest_k = np.argsort(distance)
topK = nearest_k[:k]
classCount = {
}
for i in topK:
# print(i)
classCount[labels_2[i]] = classCount.get(labels_2[i], 0) + 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
label_list.append(sortedClassCount[0][0])
# 曼哈顿距离:|x1-x2| + |y1-y2|
elif dis == 'M':
for i in range(count):
distance = np.sum(np.abs(trains_1 - np.tile(test[i], (trains_1.shape[0], 1))), axis=1)
nearest_k = np.argsort(distance)
topK = nearest_k[:k]
classCount = {
}
for i in topK:
# print(i)
classCount[labels_2[i]] = classCount.get(labels_2[i], 0) + 1
sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
label_list.append(sortedClassCount[0][0])
return label_list
def GetXMean(img_data):
"""
calculate the image's mean and std value
:param img_data: images
:return: mean, std
"""
mean = []
std = []
for i, img in enumerate(img_data):
mean.append(img[:, :].mean())
std.append(img[:, :].std())
mean = (np.array(mean) + 0.5).astype(np.int32)
std = (np.array(std) + 0.5).astype(np.int32)
return mean, std
def Centralized(img_data, mean, std):
return img_data.astype(np.int32) - mean.reshape((mean.shape[0], 1, 1))
# download MNIST dataset
train_dataset = dsets.MNIST(root=data_path,
train=True,
transform=None,
download=True)
test_dataset = dsets.MNIST(root=data_path,
train=False,
transform=None,
download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True) #数据打乱
test_loader = DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=True) #数据打乱
# 训练集样本数和维度
print("train_data: ", train_dataset.train_data.size())
# 训练集标签的长度
print("train_data: ", train_dataset.train_labels.size())
# 测试集样本数和维度
print("test_data: ", test_dataset.train_data.size())
# 测试集标签的长度
print("test_data: ", test_dataset.train_labels.size())
if __name__ == '__main__':
x_train = train_loader.dataset.train_data.numpy()
x_mean, x_std = GetXMean(x_train)
x_train = Centralized(x_train, x_mean, x_std).reshape(x_train.shape[0], 28*28)
y_train = train_loader.dataset.train_labels.numpy()
x_test = test_loader.dataset.test_data[:1000].numpy()
x_mean, x_std = GetXMean(x_test)
x_test = Centralized(x_test, x_mean, x_std).reshape(x_test.shape[0], 28*28)
y_test = test_loader.dataset.test_labels[:1000].numpy()
num_test = y_test.shape[0]
y_test_pred = KNN_distance(5, 'E', x_train, y_train, x_test)
num_correct = np.sum(y_test_pred == y_test)
accuracy = float(num_correct) / num_test
print('Euclidean: Got %d / %d correct => accuracy: %f' % (num_correct, num_test, accuracy))
y_test_pred = KNN_distance(5, 'M', x_train, y_train, x_test)
num_correct = np.sum(y_test_pred == y_test)
accuracy = float(num_correct) / num_test
print('Manhattan: Got %d / %d correct => accuracy: %f' % (num_correct, num_test, accuracy))
"""
计算输出结果:
train_data: torch.Size([60000, 28, 28])
train_data: torch.Size([60000])
test_data: torch.Size([10000, 28, 28])
test_data: torch.Size([10000])
Euclidean: Got 964 / 1000 correct => accuracy: 0.964000
Manhattan: Got 942 / 1000 correct => accuracy: 0.942000
"""
Se puede ver en los resultados anteriores que, al implementar el algoritmo KNN con la distancia Manhattan y con la distancia euclidiana respectivamente, los resultados obtenidos difieren ligeramente. La razón fundamental es que el método de cálculo de la distancia es diferente. Para más detalles, consulte el algoritmo de aprendizaje profundo KNN (algoritmo de los k vecinos más cercanos).
Antes de ejecutar el algoritmo KNN se aplica cierto preprocesamiento a las imágenes (eliminación del valor medio, normalización, etc.), con el fin de mejorar la precisión de la inferencia.