Tensorflow框架搭建全连接神经网络训练手写数字mnist数据集

本文将用Tensorflow框架训练Mnist数据集,搭建全连接神经网络,损失将以动态折线图方式展示
全连接神经网络如图所示:
在这里插入图片描述
Mnist数据集是0-9十个数字构成的图片形式的数据集,每张图片是28*28的大小在这里插入图片描述
在这里插入图片描述
导入tensorflow中带的mnist数据集,以one-hot的形式:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(".\MNIST_data",one_hot=True)

这里建立一个Net的类,self.x是神经网络数据的输入,用占位符tf.placeholder占位,输入的形状是[N,V]结构N是批次,V为28*28的=784的数据,整张图片不能直接传入神经网络,每张图片是28乘以28,要变成784乘以1,即把每个像素挨个排列送进网络。
这里用的是两层的神经网络,w是权重,截取自标准正态分布,b为偏置,设为0,因为是10分类问题,所以最后的输出有10个
感知机模型
在这里插入图片描述

class Net:
    def __init__(self):
        self.x = tf.placeholder(dtype=tf.float32,shape=[None,784])
        self.y = tf.placeholder(dtype=tf.float32,shape=[None,10])
        self.w1= tf.Variable(tf.truncated_normal(shape=[784,256],stddev=0.01,dtype=tf.float32))
        self.b1= tf.Variable(tf.zeros(shape=[256],dtype=tf.float32))
        self.w2= tf.Variable(tf.truncated_normal(shape=[256,10],stddev=0.01,dtype=tf.float32))
        self.b2= tf.Variable(tf.zeros(shape=[10],dtype=tf.float32))

定义前向:
根据公式f(wx+b),f为激活函数,第一层的输出作为第二层的输入,第一层用rule激活,最后一层用softmax激活(多分类问题最后一层要用softmax激活)

    def forward(self):
        y1 = tf.nn.relu(tf.matmul(self.x,self.w1)+self.b1)
        self.y2 = tf.matmul(y1,self.w2)+self.b2
        self.output = tf.nn.softmax(self.y2)

定义损失函数:loss,使用交叉熵损失函数softmax_cross_entropy_with_logits,具体用法请自行学习

    def loss(self):
        self.error = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y,logits=self.y2))

定义后向函数backward,使用Adam优化器优化损失,学习率为0.001

    def backward(self):
        self.optimizer = tf.train.GradientDescentOptimizer(0.001).minimize(self.error)

之后就是主函数,实例化,喂数据,训练、验证,并使用matplotlib将损失以动态折线图的形式展示出来,下面是全部程序

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(".\MNIST_data",one_hot=True)
# mnist = ".\MNIST_data"
import matplotlib.pyplot as plt

class Net:
    def __init__(self):
        self.x = tf.placeholder(dtype=tf.float32,shape=[None,784])
        self.y = tf.placeholder(dtype=tf.float32,shape=[None,10])
        self.w1= tf.Variable(tf.truncated_normal(shape=[784,256],stddev=0.01,dtype=tf.float32))
        self.b1= tf.Variable(tf.zeros(shape=[256],dtype=tf.float32))
        self.w2= tf.Variable(tf.truncated_normal(shape=[256,10],stddev=0.01,dtype=tf.float32))
        self.b2= tf.Variable(tf.zeros(shape=[10],dtype=tf.float32))
    def forward(self):
        y1 = tf.nn.relu(tf.matmul(self.x,self.w1)+self.b1)
        self.y2 = tf.matmul(y1,self.w2)+self.b2
        self.output = tf.nn.softmax(self.y2)
    def loss(self):
        self.error = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=self.y,logits=self.y2))
    def backward(self):
        self.optimizer = tf.train.GradientDescentOptimizer(0.001).minimize(self.error)
    def accuracy(self):
        y = tf.equal(tf.argmax(self.output,axis=1),tf.argmax(self.y,axis=1))
        self.acc = tf.reduce_mean(tf.cast(y,dtype=tf.float32))
if __name__ == '__main__':
    net= Net()
    net.forward()
    net.loss()
    net.backward()
    net.accuracy()
    init = tf.global_variables_initializer()
    plt.ion()
    a=[]
    b=[]
    c=[]
    with tf.Session() as sess:
        sess.run(init)
        for i in range(50000):
            xs,ys = mnist.train.next_batch(100)
            error,_ = sess.run([net.error,net.optimizer],feed_dict={net.x:xs,net.y:ys})
            if i%100 == 0:
                xss,yss = mnist.validation.next_batch(100)
                _error,_output,acc = sess.run([net.error,net.output,net.acc],feed_dict={net.x:xss,net.y:yss})
                label= np.argmax(yss[0])
                out = np.argmax(_output[0])
                print("error:",error)
                print("label:",label,"output:",out)
                print(acc)
                a.append(i)
                b.append(error)
                c.append(_error)
                plt.clf()
                train, = plt.plot(a,b,linewidth = 1,color = "red")
                validate, = plt.plot(a,c,linewidth = 1, color = "blue")
                plt.legend([train,validate],["train","validate"],loc= "right top",fontsize = 10)
                plt.pause(0.01)
    plt.ioff()

运行之前请确认mnist数据集是否已经加载进来,如果没有要自行下载mnist数据集并粘贴到这里
在这里插入图片描述

运行结果:这里不展示动态图,只截取了刚开始运行时的损失和训练一段时间之后的损失
刚开始训练的损失
在这里插入图片描述
训练一段时间之后的损失
在这里插入图片描述
结论:用全连接神经网络训练mnist数据集,可以得到较好的效果,不过全连接的计算量比较大,如果用来训练较为复杂的数据,运行速度比较慢,精度低,效果不好,所以后面会介绍卷积神经网络CNN。
如果转载或引用请注明来源!

发布了18 篇原创文章 · 获赞 2 · 访问量 350

猜你喜欢

转载自blog.csdn.net/weixin_44928646/article/details/104519603