CS231n 最近邻(NN)分类

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/sinat_34233802/article/details/70160344
#python3
import numpy as np
def unpickle(file):
    """Load one pickled dataset batch file (Python 3 version of the loader).

    The file is unpickled with ``encoding='bytes'``, so 8-bit strings stored
    in the pickle come back as ``bytes`` objects; callers therefore index the
    result with bytes keys such as ``b"data"``.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The object stored in the file.
    """
    import pickle

    with open(file, 'rb') as fo:
        # NOTE(security): pickle.load can execute arbitrary code while
        # loading; only call this on batch files from a trusted source.
        batch = pickle.load(fo, encoding='bytes')
    return batch

def load_CIFAR10(file):
    """Load the CIFAR-10 python batches stored in directory ``file``.

    ``data_batch_1`` through ``data_batch_5`` form the training split and
    ``test_batch`` forms the test split. The batches are unpickled with
    bytes encoding, so their keys must be written with a ``b`` prefix
    (``b"data"``, ``b"labels"``).

    Args:
        file: Path of the directory that contains the batch files.

    Returns:
        A tuple ``(dataTrain, labelTrain, dataTest, labelTest)`` of four
        lists holding the items of each batch's ``b"data"`` and
        ``b"labels"`` entries, in file order.
    """
    import os  # function-scope import keeps the block self-contained

    # get the training data
    dataTrain = []
    labelTrain = []
    for i in range(1, 6):
        # os.path.join uses the platform's separator; the previous
        # hard-coded "\\" only produced a valid path on Windows.
        dic = unpickle(os.path.join(file, "data_batch_" + str(i)))
        dataTrain.extend(dic[b"data"])
        labelTrain.extend(dic[b"labels"])

    # get the test data
    dic = unpickle(os.path.join(file, "test_batch"))
    dataTest = list(dic[b"data"])
    labelTest = list(dic[b"labels"])
    return dataTrain, labelTrain, dataTest, labelTest

# Load the CIFAR-10 batches and convert each of the four returned lists
# (train data, train labels, test data, test labels) into a NumPy array.
Xtr, Ytr, Xte, Yte = (
    np.asarray(part) for part in load_CIFAR10('tedata/cifar-10-batches-py')
)
# Alternative (use one approach or the other): explicitly flatten every image
# into a single row before classification.
# Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3)  # Xtr_rows becomes 50000 x 3072
# Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3)  # Xte_rows becomes 10000 x 3072

class NearestNeighbor(object):
    """1-nearest-neighbour classifier based on Euclidean (L2) distance."""

    def __init__(self):
        pass

    def train(self, X, y):
        """Memorise the training set; no computation happens here.

        Args:
            X: 2-D array-like with one training example per row.
            y: 1-D NumPy array holding the label of each row of ``X``.
        """
        self.xtr = X
        self.ytr = y

    def predict(self, X):
        """Return the label of the closest training row for every row of ``X``.

        Args:
            X: 2-D array-like with one example per row and the same number
                of columns as the training data.

        Returns:
            1-D NumPy array with one predicted label per row of ``X``, using
            the dtype of the training labels.
        """
        # Do the arithmetic in float64. With integer inputs (e.g. uint8 pixel
        # values) the subtraction and the square would wrap around and the
        # wrong neighbour would be selected. The conversion is hoisted out of
        # the loop so it happens only once; it holds a float64 copy of the
        # training data for the duration of the call.
        train_rows = np.asarray(self.xtr, dtype=np.float64)
        queries = np.asarray(X, dtype=np.float64)

        num_test = queries.shape[0]
        # make sure that the output type matches the label type
        Ypred = np.zeros(num_test, dtype=self.ytr.dtype)

        # loop over all test rows
        for i in range(num_test):
            # The square root is monotonic, so the squared distance already
            # identifies the nearest training example.
            sq_distances = np.sum(np.square(train_rows - queries[i, :]), axis=1)
            nearest = np.argmin(sq_distances)  # index of the smallest distance
            Ypred[i] = self.ytr[nearest]  # label of the nearest example
        return Ypred

# Build the nearest-neighbour classifier, hand it the training images and
# labels, then label every test image.
nn = NearestNeighbor()
nn.train(Xtr, Ytr)
Yte_predict = nn.predict(Xte)

# Report the classification accuracy: the fraction of test examples whose
# predicted label matches the true label.
print('accuracy: %f' % np.mean(np.equal(Yte_predict, Yte)))

这段代码运行起来比较耗时,我跑了一个小时才得到结果。主要原因是在 predict 过程中,需要对 10000 条测试数据逐条计算它与全部训练样本之间的距离。

猜你喜欢

转载自blog.csdn.net/sinat_34233802/article/details/70160344