MNIST Handwritten Digit Recognition (TensorFlow)

Official MNIST dataset page: MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

The MNIST dataset contains 70,000 images in total: 60,000 in the training set and 10,000 in the test set. Each image is a 28 × 28 picture of a handwritten digit 0-9, drawn as white strokes on a black background: background pixels are 0 and stroke pixels are floating-point values between 0 and 1, where values closer to 1 are whiter.

Training images: train-images-idx3-ubyte.gz (9.45 MB, 60,000 samples).
Training labels: train-labels-idx1-ubyte.gz (28.2 KB, 60,000 labels).
Test images: t10k-images-idx3-ubyte.gz (1.57 MB, 10,000 samples).
Test labels: t10k-labels-idx1-ubyte.gz (4.43 KB, 10,000 labels).

Each training/test image is flattened into a 1×784 vector of the form [0, 0, 0, 0.435, 0.754, 0.760, 0.320, 0, 0, ⋯, 0, 0, 0].

Each label is a 1×10 one-hot vector of the form [0, 0, 0, 0, 0, 1, 0, 0, 0, 0].
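As a quick sanity check, the snippet below loads the dataset with the same input_data helper used in the full program and prints these shapes plus a sample one-hot label. It is only a minimal sketch and assumes the deprecated tensorflow.examples.tutorials.mnist module is still available in your TensorFlow installation (it was removed from recent releases).

from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

print(mnist.train.images.shape)   # (55000, 784): flattened 28x28 images, pixel values in [0, 1]
                                  # (55,000 rather than 60,000 because input_data holds out 5,000 validation images)
print(mnist.test.images.shape)    # (10000, 784)
print(mnist.train.labels.shape)   # (55000, 10): one-hot labels
print(mnist.train.labels[0])      # a 1x10 vector with a single 1 at the digit's index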

import tensorflow.compat.v1 as tf
from tensorflow.examples.tutorials.mnist import input_data

tf.disable_eager_execution()   # the graph/Session API used below requires eager execution to be off

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

learning_rate = 1e-4
keep_prob_rate = 0.7   # dropout keep probability used during training
max_epoch = 2000       # number of training iterations (mini-batches), not full passes over the data

def compute_accuracy(v_xs, v_ys):
    # Run the network on v_xs with dropout disabled (keep_prob=1) and return the
    # fraction of samples whose argmax prediction matches the one-hot label v_ys.
    global prediction
    y_pre = sess.run(prediction, feed_dict={xs: v_xs, keep_prob: 1})
    correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(v_ys, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys, keep_prob: 1})
    return result

def weight_variable(shape):
    # small random initial weights (truncated normal) to break symmetry
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    # small positive initial bias, which works well with ReLU units
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def conv2d(x, W):
    # Stride 1 in every dimension and SAME padding (hint: use tf.nn.conv2d),
    # so the spatial size is preserved:
    # x = [batch, 28, 28, 1], W = [7, 7, 1, 32] -> output [batch, 28, 28, 32]
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    # 2x2 pooling window with stride 2 and SAME padding (hint: use tf.nn.max_pool);
    # ksize and strides are [1, height, width, 1], so this halves the spatial dimensions
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
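
# For reference: with SAME padding the output spatial size is ceil(input_size / stride).
# conv2d above keeps 28 -> 28, while each max_pool_2x2 halves it, 28 -> 14 -> 7, which is
# where the 7*7*64 flattened size of the first fully connected layer below comes from.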

# define placeholders for the network inputs
# training data has shape (55000, 784); input_data already scales pixels to [0, 1],
# so no further /255. rescaling is needed
xs = tf.placeholder(tf.float32, [None, 784])
ys = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)
x_image = tf.reshape(xs, [-1, 28, 28, 1])   # reshape flat vectors back into 28x28 single-channel images

# Convolution layer 1
W_conv1 = weight_variable([7, 7, 1, 32])    # 7x7 patch, 1 input channel, 32 output channels
b_conv1 = bias_variable([32])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)   # convolution + ReLU activation -> [batch, 28, 28, 32]
h_pool1 = max_pool_2x2(h_conv1)                            # pooling -> [batch, 14, 14, 32]
# print(h_pool1.shape)
# Convolution layer 2
W_conv2 = weight_variable([5, 5, 32, 64])   # 5x5 patch, 32 input channels, 64 output channels
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)   # convolution + ReLU activation -> [batch, 14, 14, 64]
h_pool2 = max_pool_2x2(h_conv2)                            # pooling -> [batch, 7, 7, 64]
# print(h_pool2.shape)

# Fully connected layer 1
W_fc1 = weight_variable([7*7*64, 1024])
b_fc1 = bias_variable([1024])

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])            # flatten -> [batch, 3136]
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)  # [batch, 1024]
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)                # dropout to reduce overfitting

# Fully connected layer 2 (output layer)
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
prediction = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)   # [batch, 10] class probabilities
# print(prediction)


# Cross-entropy loss: reduce_sum() sums along the axes given by reduction_indices
# (here axis 1, i.e. over the 10 classes), and reduce_mean() averages over the batch.
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
                                              reduction_indices=[1]))
# Minimize the loss with the Adam optimizer
train_step = tf.train.AdamOptimizer(learning_rate).minimize(cross_entropy)
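
# Note: computing softmax and then log can be numerically unstable when a predicted
# probability underflows to 0. A common alternative is to keep the raw layer-2 output
# as logits and use the fused op, e.g.:
#   logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2
#   cross_entropy = tf.reduce_mean(
#       tf.nn.softmax_cross_entropy_with_logits_v2(labels=ys, logits=logits))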

with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)

    for i in range(max_epoch):
        # train on one mini-batch of 100 images with dropout enabled
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: keep_prob_rate})
        if i % 100 == 0:
            # every 100 iterations, report accuracy on the first 1,000 test images
            print(compute_accuracy(
                mnist.test.images[:1000], mnist.test.labels[:1000]))
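
    # Optional: after training finishes, evaluate once on the full 10,000-image
    # test set instead of only the first 1,000 images.
    print('final test accuracy:', compute_accuracy(mnist.test.images, mnist.test.labels))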

Running results:


Reprinted from blog.csdn.net/qq_42018521/article/details/130351550