卷积神经网络分类MNIST手写体数字

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import  input_data
# Load the MNIST train/validation/test splits from the "MNIST_data" directory,
# with labels encoded as one-hot vectors.
mnist = input_data.read_data_sets(
    "MNIST_data",
    one_hot=True,
)
import matplotlib.pyplot as plt

class Net:
    """Small CNN classifier for 28x28 single-channel images (MNIST).

    Graph built by ``forward``:
        conv 3x3, 1->16, ReLU  -> max-pool 2x2 stride 2 (28x28 -> 14x14)
        conv 3x3, 16->32, ReLU -> max-pool 2x2 stride 2 (14x14 -> 7x7)
        flatten (7*7*32) -> dense 128, ReLU -> dense 10 -> softmax

    Usage: call ``forward()`` first, then ``backward()`` (it reads tensors that
    ``forward`` creates), then feed ``x`` / ``y`` through ``tf.Session.run``.
    """

    def __init__(self):
        # Input images in NHWC layout: (batch, 28, 28, 1).
        self.x = tf.placeholder(tf.float32, [None, 28, 28, 1])
        # Target labels as 10-way vectors (fed one-hot by the training script).
        self.y = tf.placeholder(tf.float32, [None, 10])
        # Convolution kernels (height, width, in_channels, out_channels) and biases.
        self.conv1_w = tf.Variable(tf.random_normal([3, 3, 1, 16], dtype=tf.float32, stddev=0.1))
        self.conv1_b = tf.Variable(tf.zeros([16]))
        self.conv2_w = tf.Variable(tf.random_normal([3, 3, 16, 32], dtype=tf.float32, stddev=0.1))
        self.conv2_b = tf.Variable(tf.zeros([32]))
        # Fully connected layers: flattened 7x7x32 feature map -> 128 -> 10.
        self.w1 = tf.Variable(tf.random_normal([7 * 7 * 32, 128], stddev=0.1))
        self.b1 = tf.Variable(tf.zeros([128]))
        self.w2 = tf.Variable(tf.random_normal([128, 10], stddev=0.1))
        self.b2 = tf.Variable(tf.zeros([10]))

    def forward(self):
        """Build the inference graph.

        Sets ``self.logits`` (pre-softmax scores) and ``self.y2`` (softmax
        probabilities), plus the intermediate layer tensors.
        """
        self.conv1 = tf.nn.relu(
            tf.nn.conv2d(self.x, self.conv1_w, strides=[1, 1, 1, 1], padding='SAME') + self.conv1_b)
        self.pool1 = tf.nn.max_pool(self.conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        self.conv2 = tf.nn.relu(
            tf.nn.conv2d(self.pool1, self.conv2_w, strides=[1, 1, 1, 1], padding='SAME') + self.conv2_b)
        self.pool2 = tf.nn.max_pool(self.conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
        # Two stride-2 pools take 28x28 down to 7x7 with 32 channels.
        self.flat = tf.reshape(self.pool2, [-1, 7 * 7 * 32])
        self.y1 = tf.nn.relu(tf.matmul(self.flat, self.w1) + self.b1)
        # Keep the logits: the loss is computed from them (numerically stable),
        # while y2 stays available as probabilities for prediction/accuracy.
        self.logits = tf.matmul(self.y1, self.w2) + self.b2
        self.y2 = tf.nn.softmax(self.logits)

    def backward(self):
        """Build the loss, the Adam training op and the batch accuracy.

        Requires ``forward()`` to have been called (uses ``self.logits`` and
        ``self.y2``).
        """
        # Softmax cross-entropy on the logits is the standard classification
        # loss. MSE on softmax outputs gives weak gradients when the softmax
        # saturates on a wrong class.
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(labels=self.y, logits=self.logits))
        self.opt = tf.train.AdamOptimizer().minimize(self.loss)
        # Per-example correctness: predicted class index == true class index.
        self.prediction_correct = tf.equal(tf.argmax(self.y2, 1), tf.argmax(self.y, 1))
        # Original (misspelled) attribute name kept as an alias for existing callers.
        self.prediction_corect = self.prediction_correct
        # Boolean -> float32 so the mean is the fraction of correct predictions.
        self.rst = tf.cast(self.prediction_correct, tf.float32)
        # Accuracy as a fraction in [0, 1] (not a percentage).
        self.accuracy = tf.reduce_mean(self.rst)

if __name__ == '__main__':
    # Train the CNN on MNIST, live-plot per-batch accuracy/loss, then report
    # accuracy on the held-out test split.
    BATCH_SIZE = 100   # examples per training step
    NUM_STEPS = 1000   # number of training steps
    PLOT_EVERY = 10    # refresh the plots every N steps

    net = Net()
    net.forward()
    net.backward()
    init = tf.global_variables_initializer()

    # Create the figure (1 row, 2 columns), its axes and one line per axes
    # once. Later refreshes only update the line data. Re-calling plt.plot
    # would stack a new line artist on every refresh and make redraws slower
    # each time.
    _, (acc_ax, loss_ax) = plt.subplots(1, 2)
    acc_ax.set_title('accuracy rate')
    loss_ax.set_title('loss')
    acc_line, = acc_ax.plot([], [])
    loss_line, = loss_ax.plot([], [])

    with tf.Session() as sess:
        sess.run(init)
        steps = []       # training step indices (x-axis)
        accuracies = []  # per-batch training accuracy
        losses = []      # per-batch training loss
        for i in range(NUM_STEPS):
            steps.append(i)
            x, y = mnist.train.next_batch(BATCH_SIZE)
            # Reshape to the (batch, 28, 28, 1) layout of the model's placeholder.
            x = x.reshape([-1, 28, 28, 1])
            loss, acc, _ = sess.run([net.loss, net.accuracy, net.opt],
                                    feed_dict={net.x: x, net.y: y})
            accuracies.append(acc)
            losses.append(loss)
            if i % PLOT_EVERY == 0:
                for ax, line, values in ((acc_ax, acc_line, accuracies),
                                         (loss_ax, loss_line, losses)):
                    line.set_data(steps, values)
                    ax.relim()            # recompute data limits from the new data
                    ax.autoscale_view()   # rescale the view to those limits
                plt.pause(0.0001)         # draw and let the GUI process events
            print(loss, acc)

        # The values above are training-batch accuracy; also measure it on the
        # held-out test split once training finishes.
        test_acc = sess.run(net.accuracy, feed_dict={
            net.x: mnist.test.images.reshape([-1, 28, 28, 1]),
            net.y: mnist.test.labels,
        })
        print('test accuracy:', test_acc)

猜你喜欢

转载自blog.csdn.net/weixin_38241876/article/details/85211675
今日推荐