TensorFlow MNIST 数据集练习

# -*- coding: UTF-8 -*-
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST train/validation/test splits from the 'data/' directory.
# one_hot=True makes every label a length-10 indicator vector rather than an
# integer class id.
mnist = input_data.read_data_sets('data/', one_hot=True)

# Convenience aliases for the raw training and test arrays.
trainimg, trainlabel = mnist.train.images, mnist.train.labels
testimg, testimglabel = mnist.test.images, mnist.test.labels
# Mini-batch size used by the training loop.
batch_size = 100

# Placeholders declare graph inputs without holding any data; the None
# dimension leaves the batch size open so any number of examples can be fed
# through feed_dict at run time.
x = tf.placeholder('float', [None, 784])  # 784 input features per example
y = tf.placeholder('float', [None, 10])   # one-hot label over 10 classes

# Parameters of the softmax-regression classifier, initialised to zero.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))

# Softmax regression: linear scores x.W + b turned into class probabilities.
# Background: http://www.tensorfly.cn/tfdoc/tutorials/mnist_beginners.html
logits = tf.matmul(x, W) + b
actv = tf.nn.softmax(logits)

# Mean cross-entropy over the batch. log_softmax(logits) equals
# log(softmax(logits)) mathematically, but is computed stably: taking
# tf.log of a softmax output that underflows to 0 would give -inf and turn
# the loss (and gradients) into NaN.
cost = tf.reduce_mean(-tf.reduce_sum(y * tf.nn.log_softmax(logits), axis=1))

learn_rate = 0.01
# Plain gradient descent on the batch loss.
optm = tf.train.GradientDescentOptimizer(learn_rate).minimize(cost)

# A prediction is correct when the most probable class (argmax over axis 1,
# the class axis) equals the class encoded by the one-hot label.
pred = tf.equal(tf.argmax(actv, 1), tf.argmax(y, 1))
# Accuracy is the fraction of correct predictions: tf.cast turns the boolean
# tensor into floats (True -> 1.0, False -> 0.0) before averaging.
accr = tf.reduce_mean(tf.cast(pred, "float"))

init = tf.global_variables_initializer()

# Training schedule.
training_ecpchs = 50  # number of epochs (identifier spelling kept as-is)
display = 5           # report accuracies every `display` epochs

# Train with mini-batch gradient descent, printing the mean epoch loss plus
# training/test accuracy every `display` epochs.
with tf.Session() as sess:
    sess.run(init)
    # Number of mini-batches per epoch; it does not change between epochs,
    # so compute it once.
    num_batch = mnist.train.num_examples // batch_size
    for epoch in range(training_ecpchs):
        avg_cost = 0.0
        for _ in range(num_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            feeds = {x: batch_xs, y: batch_ys}
            # Fetch the update op and the loss in a single run so the forward
            # pass is not executed twice per batch; batch_cost is the loss
            # computed in this same step.
            _, batch_cost = sess.run([optm, cost], feed_dict=feeds)
            avg_cost += batch_cost / num_batch
        if epoch % display == 0:
            # NOTE: training accuracy is measured on the last mini-batch only,
            # so it is a noisy estimate; test accuracy uses the full test set.
            feeds_train = {x: batch_xs, y: batch_ys}
            feeds_test = {x: mnist.test.images, y: mnist.test.labels}
            train_acc = sess.run(accr, feed_dict=feeds_train)
            test_acc = sess.run(accr, feed_dict=feeds_test)
            print("Epoch: %03d/%03d cost: %.9f train_acc: %.3f test_acc: %.3f"
                  % (epoch, training_ecpchs, avg_cost, train_acc, test_acc))


发布了56 篇原创文章 · 获赞 3 · 访问量 5万+

猜你喜欢

转载自blog.csdn.net/haoxuezhe1988/article/details/79247976
今日推荐