使用MNIST数据集训练模型，并动态绘制Loss和准确率的趋势图

MNIST训练的代码直接参考了 http://www.pianshen.com/article/3783475094/

由于原代码输出的训练结果不够直观，因此使用 pyplot 将其绘制成动态趋势图，代码如下：

#coding: utf-8
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST with one-hot labels. With the default float32 dtype,
# read_data_sets already rescales pixel values into [0, 1], so the inputs
# must NOT be divided by 255 again. (The previous `placeholder(...)/255`
# was never applied anyway: the session feeds `input_x` directly, which in
# TF1 overrides the division tensor and bypasses it.)
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)

# Flattened 28x28 grayscale images, values expected in [0, 1].
input_x = tf.placeholder(tf.float32, [None, 28 * 28])
# One-hot labels for the 10 digit classes.
output_y = tf.placeholder(tf.int32, [None, 10])
# Reshape to (batch, 28, 28, 1) for the conv layers.
input_x_images = tf.reshape(input_x, [-1, 28, 28, 1])

# Fixed evaluation subset: the first 3000 test images and labels.
test_x = mnist.test.images[:3000]
test_y = mnist.test.labels[:3000]

# Network: conv(5x5, 32) -> maxpool(2x2) -> conv(5x5, 64) -> maxpool(2x2)
#          -> dense(1024, relu) -> dropout(0.5) -> dense(10 logits)
conv1 = tf.layers.conv2d(
    inputs=input_x_images,
    filters=32,
    kernel_size=[5, 5],
    strides=1,
    padding='same',
    activation=tf.nn.relu,
)
pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2)

conv2 = tf.layers.conv2d(
    inputs=pool1,
    filters=64,
    kernel_size=[5, 5],
    strides=1,
    padding='same',
    activation=tf.nn.relu,
)
pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2)

# 'same' padding keeps 28x28 through each conv; the two stride-2 pools
# reduce it to 7x7 with 64 channels.
flat = tf.reshape(pool2, [-1, 7 * 7 * 64])

dense = tf.layers.dense(inputs=flat, units=1024, activation=tf.nn.relu)

# tf.layers.dropout is a no-op unless `training` is True. Default the flag to
# False so evaluation is deterministic; the training step feeds True.
is_training = tf.placeholder_with_default(False, shape=())
dropout = tf.layers.dropout(inputs=dense, rate=0.5, training=is_training)

logits = tf.layers.dense(inputs=dropout, units=10)

loss = tf.losses.softmax_cross_entropy(onehot_labels=output_y, logits=logits)

train_op = tf.train.AdamOptimizer(learning_rate=0.005).minimize(loss)

# Accuracy of the batch that is currently fed. tf.metrics.accuracy is a
# streaming metric whose counters accumulate over every evaluation since
# initialization, which would report a running average of all past test
# runs instead of the current model's accuracy.
correct = tf.equal(tf.argmax(output_y, axis=1), tf.argmax(logits, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

NUM_STEPS = 20000
BATCH_SIZE = 50
EVAL_EVERY = 100
# Short pause: just long enough for the GUI to redraw. A 1 s pause on each of
# the 200 evaluations would add minutes of idle time to training.
PLOT_PAUSE_SECONDS = 0.1

with tf.Session() as sess:
    sess.run(tf.group(tf.global_variables_initializer(),
                      tf.local_variables_initializer()))

    # History of (step, train loss, test accuracy) for the live plot.
    steps = []
    losses = []
    accuracies = []

    for i in range(NUM_STEPS):
        batch_x, batch_y = mnist.train.next_batch(BATCH_SIZE)
        train_loss, _ = sess.run(
            [loss, train_op],
            {input_x: batch_x, output_y: batch_y, is_training: True},
        )
        if i % EVAL_EVERY != 0:
            continue

        # is_training defaults to False here, so dropout is disabled.
        test_accuracy = sess.run(accuracy, {input_x: test_x, output_y: test_y})
        print("Step=%d, Train loss=%.4f, [Test accuracy=%.2f]" % (i, train_loss, test_accuracy))

        steps.append(i)
        losses.append(train_loss)
        accuracies.append(test_accuracy)

        # Redraw the whole chart from the accumulated history.
        plt.cla()
        plt.plot(steps, losses, label='Loss rate', color='g')
        plt.plot(steps, accuracies, label='Accuracy rate', color='b')
        plt.xlabel('Step')
        plt.ylabel('Rate')
        plt.title('Train table')
        plt.legend()
        plt.pause(PLOT_PAUSE_SECONDS)

# Keep the final chart on screen instead of closing it when the script exits.
plt.show()

结果直接上图

最终的趋势结果 

发布了2 篇原创文章 · 获赞 0 · 访问量 42

猜你喜欢

转载自blog.csdn.net/HUAWEIZHIZHE/article/details/104123999