cs231n lecture2 image classification

图片分类若采用最近邻法:

 1 import numpy as np
 2 
 3 class NearestNeighbor:
 4     def _init_(self):
 5         pass
 6   
 7     def train(self, X, y):
 8         self.Xtr = X
 9         self.ytr = y
10   
11     def predict(self, X):
12         num_test = X.shape[0]
13         Y_pred = np.zeros(num_test, dtype = self.ytr.dtype)
14     
15         for i in range(num_test):
16             distances = np.sum(np.abs(self.Xtr - X[i, :]), axis = 1)
17             min_index = np.argmin(distances)
18             Y_pred[i] = self.ytr[min_index]
19     
20     return Y_pred

train函数时间复杂度为O(1),test函数时间复杂度为O(n),n为训练集大小。

训练时间短,但测试时间过长。

一般最好是训练时间较长,测试时间短,如CNN。

猜你喜欢

转载自www.cnblogs.com/lxc1910/p/11300218.html