K-nn(k邻近学习)

knn算法:

第一步:将每个例子做成一个点,他所对应发特征向量是一个多维的坐标

第二步:自己选择一个参数k
第三步:计算未知实例与所有已知实例的欧式距离(也可以是其他距离),将他们排序。
第四步:选择根据排序好的k去选择k个已知实例
第五步:在这k个实例中,根据少数服从多数的原则,让未知实例归类到为最多数的类别
程序如下:
import numpy as  np
from sklearn import datasets
#求出每个测试集与训练集的欧氏距离,经过从小到大的排序,取前k个的,然后经过投票,得到预测值。
def k_nn(train_x, train_y, test_data, k):
    m_train = np.shape(train_x)[0] #5000*784
    predict = []
    k_index={}
    for i_train in range(m_train):
        k_index[i_train]=cal_k_index(train_x[i_train,:],test_data)
    k_result = sorted(k_index.items(), key=lambda d: d[1])
    result = []
    count = 0
    for x in k_result:
        if count >= k:
            break
        result.append(x[0])
        count += 1
    predict.append(get_prediction(train_y, result))# 得到预测
    return predict
#根据少数服从多数的原则,得到测试集的预测值。
def get_prediction(train_y,result):
    result_dict = {}
    for i in range(len(result)):
        if train_y[result[i]] not in result_dict:
            result_dict[train_y[result[i]]] = 1
        else:
            result_dict[train_y[result[i]]] += 1
    predict = sorted(result_dict.items(), key=lambda d: d[1])
    return predict[0][0]
#计算欧式距离
def cal_k_index(train_x,test_date):
    dist = np.sqrt(np.sum(np.square( train_x-test_date)))
    return dist
#计算预测值的精度
def get_correct_rate(result, test_y):
    m = len(result)
    correct=0
    for i in range(m):
        if result[i] == test_y[i]:
            correct += 1
    return correct / m
def main():
    # train_x, train_y, test_x, test_y = load_file()#当使用读取文件时使用
    #从机器学习sklearn库导出iris库 iris是一个150*4的库
    iris=datasets.load_iris()
    sourcedata=iris['data']  #得到属性向量
    labledata=iris['target'] #得到类别标记
    #划分训练集和测试集
    train_x=sourcedata[0:139,:]
    train_y=labledata[0:139]
    test_x=sourcedata[140:150,:]
    test_y=labledata[140:150]
    global result
    result={}
    for i in range(10):
        result[i] = k_nn(train_x, train_y,test_x[i,:],10)
    correct=get_correct_rate(result, test_y)
    print(result,correct)
if __name__=='__main__':
    main()
如果是kpl文件就加入以下程序:
 
  
读取压缩的pkl格式的文件
def  load_file():
    with gzip.open("./mnist.pkl.gz",'rb') as fp:
        training_data, valid_data, test_data = pickle.load(fp,encoding='iso-8859-1')
    return training_data[0], training_data[1], test_data[0],test_data[1]





猜你喜欢

转载自blog.csdn.net/qq_23859701/article/details/78955236