基于Python神经网络的手写字体识别

       本文将分享实现手写字体识别的神经网络实现,代码中有详细注释以及我自己的一些体会,希望能帮助到大家 (≧∇≦)/ 

##############################################手写字体识别#############################################
import numpy
import scipy.special
######################################定义神经网络的类#########################################
class neuralNetwork:
    # 初始化
    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
        # 设置输入层、隐藏层、输出层节点个数
        self.inodes = inputnodes
        self.hnodes = hiddennodes
        self.onodes = outputnodes

        self.wih = numpy.random.normal(0.0, pow(self.inodes, -0.5), (self.hnodes, self.inodes))
        self.who = numpy.random.normal(0.0, pow(self.hnodes, -0.5), (self.onodes, self.hnodes))

        # 学习率
        self.lr = learningrate

        # 激活函数
        self.activation_function = lambda x: scipy.special.expit(x)

        pass

    # 训练
    def train(self, inputs_list, targets_list):
        # 转化为二维数组
        inputs = numpy.array(inputs_list, ndmin=2).T
        targets = numpy.array(targets_list, ndmin=2).T

        # 计算隐藏层输入
        hidden_inputs = numpy.dot(self.wih, inputs)
        # 计算隐藏层输出
        hidden_outputs = self.activation_function(hidden_inputs)

        # 计算输出层的输入
        final_inputs = numpy.dot(self.who, hidden_outputs)
        # 计算输出层的输出
        final_outputs = self.activation_function(final_inputs)

        # 计算偏差
        output_errors = targets - final_outputs
       
        hidden_errors = numpy.dot(self.who.T, output_errors)

        # 更新权重
        self.who += self.lr * numpy.dot((output_errors * final_outputs * (1.0 - final_outputs)),
                                        numpy.transpose(hidden_outputs))

        self.wih += self.lr * numpy.dot((hidden_errors * hidden_outputs * (1.0 - hidden_outputs)),
                                        numpy.transpose(inputs))

        pass

    # 验证
    def query(self, inputs_list):
        # 转化为二维数组
        inputs = numpy.array(inputs_list, ndmin=2).T

        # 计算隐藏层输入
        hidden_inputs = numpy.dot(self.wih, inputs)
        # 计算隐藏层输出
        hidden_outputs = self.activation_function(hidden_inputs)

        # 计算输出层的输入
        final_inputs = numpy.dot(self.who, hidden_outputs)
        # 计算输出层的输出
        final_outputs = self.activation_function(final_inputs)

        return final_outputs

############################################init###############################################
# 输入、输出、隐藏层节点个数
input_nodes = 784
hidden_nodes = 500
output_nodes = 10
# 学习率
learning_rate = 0.1
# 训练轮次
epochs=5
#####################################创建一个神经网络############################################
n = neuralNetwork(input_nodes, hidden_nodes, output_nodes, learning_rate)
##########################################训练#################################################
train_file = open("mnist_dataset/MNIST_csv/mnist_train.csv",'r')
train_list = train_file.readlines()
train_file.close()

for e in range(epochs):
    for record in train_list:
        value_train = record.split(',')
        inputs = (numpy.asfarray(value_train[1:])/255.0*0.99)+0.01 # 数据预处理,把像素值范围限定到[0.01,1]

        targets=numpy.zeros(output_nodes)+0.01
        targets[int(value_train[0])] = 0.99
        #创建训练时目标序列(除了目标是0.99,其余9个都是0.01);而根据数据集特征 all_value[0]是每一组数的目标标签

        n.train(inputs,targets) #训练
    print("第 %d 代训练完成" % e)
##########################################测试#################################################
data_file = open("mnist_dataset/MNIST_csv/mnist_test.csv",'r')
test_data_list = data_file.readlines()
data_file.close()

#测试,并统计准确率
scorecard=[]
for record in test_data_list:
    value_test = record.split(',')
    correct_label=int(value_test[0])
    # print(correct_label,"正确答案")
    inputs = (numpy.asfarray(value_test[1:]) / 255.0 * 0.99) + 0.01
    outputs = n.query(inputs)
    label = numpy.argmax(outputs)
    # print(label,"网络给出的答案")
    if (label==correct_label):
        scorecard.append(1)
    else:
        scorecard.append(0)
#print(scorecard)

#打分
scorecard_array=numpy.asarray(scorecard)
print("performance = ",scorecard_array.sum()/scorecard_array.size)

       训练集和数据集的.csv格式文件可以从下面链接中获取:

链接:https://pan.baidu.com/s/1Wc55qHUQPbPBNz4uYdY8nA 
提取码:2pq6 
--来自百度网盘超级会员V3的分享

求学路上,你我共勉(๑•̀ㅂ•́)و✧

猜你喜欢

转载自blog.csdn.net/Albert_yeager/article/details/129886768