神经网络实现手写数字识别(Octave转换为Python)

详细代码参考github

神经网络实现手写数字识别

实例:利用神经网络实现手写数字的识别,网络已经训练好,权重参数已给出。

1.载入数据和权重

由于给出了权重的参数,即神经网络已经训练好了,我们直接拿权重对现有的图片进行预测。

载入输入数据:
参考代码:

def loadData(self, path):
   	self.data = scio.loadmat(path)
   	self.x = self.data["X"]  # (5000, 400)  # 原100训练
   	self.y = self.data["y"]  # (5000, 1)
   	index = random.sample([i for i in range(5000)], 100)  # 随机100个没有重复的数字
   	self.pics = self.x[index, :]  # (100, 400)

载入权重参数
参考代码:

def loadWeights(self, path):
   	weights = scio.loadmat(path)
   	self.theta1 = weights['Theta1']  # 25*401
   	self.theta2 = weights['Theta2']  # 10*26
2.神经网络构建

神经网络共3层,输入层,1层隐藏层,输出层:输入层401个输入(第1个为1), 隐藏层26个单元,输出层10个单元(对应着0-9),如下图
在这里插入图片描述

3.对全部数据进行准确率验证

利用吴老师训练好的结果,进行验证,准确率97.52%。

参考代码:

def predictNN(self):
   	x = np.hstack([np.ones((self.x.shape[0], 1)), self.x])  # 5000*401
   	x1 = self.sigmoid(x.dot(self.theta1.T))  # (5000, 401)*(401, 25)

   	x1_mid = np.hstack([np.ones((x1.shape[0], 1)), x1])
   	x2 = self.sigmoid(x1_mid.dot(self.theta2.T))  # (5000, 26)*(26, 10)
   	position = np.argmax(x2, axis=1) + 1   # 预测值
   	accuracy = np.mean(position.reshape(5000, 1) == self.y) * 100
   	print("神经网络准确率是:{}".format(accuracy))  # 97.52%
4.随机抽出一张图片,对图片中的数字进行验证

从5000张图片中随机抽取一张,利用神经网络计算得到预测结果。将预测结果做成图片的title进行显示,关闭图片后,提示若继续验证,请按回车;退出请按q键,展示两次的预测结果,可以看到已准确识别!
预测7预测5
参考代码:

def predictOne(self, image):
   	x = np.hstack([np.ones((image.shape[0], 1)), image])  # 1*401
   	x1 = self.sigmoid(x.dot(self.theta1.T))  # (1, 401)*(401, 25)

   	x1_mid = np.hstack([np.ones((x1.shape[0], 1)), x1])
   	x2 = self.sigmoid(x1_mid.dot(self.theta2.T))  # (1, 26)*(26, 10)

   	position = np.argmax(x2, axis=1) + 1
   	return position
   	
def displayTestPics(self, image):
   	max_val = np.max(np.abs(image))
   	im = image.reshape((20, 20)).transpose()/max_val*255
   	predict_result = self.predictOne(image)
   	plt.xticks([])
   	plt.yticks([])
   	# 由于0用10表示,为了显示准确,取了余数.
   	plt.title("The Prediction Result is {}!".format(np.mod(predict_result[0], 10)), color='r', fontsize=20)
   	plt.imshow(im, cmap='gray')
   	plt.show()

猜你喜欢

转载自blog.csdn.net/u013617229/article/details/84866562
今日推荐