Convolutional + Pooling + Fully Connected Layers: Training on the MNIST Dataset


Train on the MNIST dataset using convolutional layers, pooling layers, and fully connected layers.

# coding:utf-8

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

'''
Steps:
1 Load the MNIST dataset
Input size: None*28*28

2 First convolutional layer
Input: placeholder [None, 28*28]    Kernel size: 5x5, stride 1      Input depth: 1       Number of kernels: 32
Pooling window: 2x2                 Pooling type: max pooling

3 Second convolutional layer
Input: None*14*14*32 (the previous layer has 32 kernels, and pooling halves the image to 14*14)
Kernel size: 5x5, stride 1          Input depth: 32                  Number of kernels: 64
Pooling window: 2x2                 Pooling type: max pooling

4 Third layer: fully connected
Input: None*7*7*64
Weight matrix size: [7*7*64, 1024] (1024 is a user-chosen width)
Output: None*1024

5 Softmax layer
Input: None*1024
Weight matrix size: 1024*10
Output: None*10

6 Train and evaluate
'''

'''
Helper functions
'''
# Weight initialization
def weight_variable(shape):
    return tf.Variable(tf.truncated_normal(shape=shape, stddev=0.1))
# Bias initialization
def bias_variable(shape):
    return tf.Variable(tf.constant(0.1, shape=shape))
# Convolution: convolve the input image with the kernel
def conv2d(img_array, W):
    # strides: [batch, height, width, channels]
    # The first 1 is the stride along the batch dimension, so no samples are skipped
    # height: the kernel's vertical stride
    # width: the kernel's horizontal stride
    # The last 1 is the stride along the channel dimension, so no channels are skipped
    return tf.nn.conv2d(img_array, W, strides=[1, 1, 1, 1], padding='SAME')
# Max pooling
def max_pool_2x2(conv):
    # ksize: the size of the pooling kernel
    # Its four entries correspond to batch, height, width, channels
    return tf.nn.max_pool(conv, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
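
# A minimal shape sanity check (an added sketch, not part of the original post): run a dummy
# batch through the helpers above to confirm the shapes quoted in the step comments. With
# 'SAME' padding and stride 1, conv2d keeps the 28*28 spatial size, and the 2*2 max pooling
# with stride 2 halves it to 14*14.
def shape_demo():
    dummy = tf.zeros([1, 28, 28, 1])    # one fake 28*28 grayscale image
    w = weight_variable([5, 5, 1, 32])  # 5*5 kernels, depth 1 in, 32 filters out
    conv = conv2d(dummy, w)             # static shape: (1, 28, 28, 32)
    pool = max_pool_2x2(conv)           # static shape: (1, 14, 14, 32)
    print(conv.get_shape(), pool.get_shape())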

if __name__ == '__main__':
    '''
    1 Load the MNIST dataset
    '''
    mnist_data = input_data.read_data_sets("MNIST_data", one_hot=True)

    '''
    2 First convolutional layer
    Input: placeholder [None, 28*28]    Kernel size: 5x5, stride 1      Input depth: 1       Number of kernels: 32
    Pooling window: 2x2                 Pooling type: max pooling
    '''
    x = tf.placeholder(tf.float32, shape=[None, 28*28])
    y_= tf.placeholder(tf.float32, shape=[None, 10])
    # Reshape the flattened input back into 28*28 single-channel images
    img_array = tf.reshape(x, [-1, 28, 28, 1]) # -1 means the first dimension is inferred from x
    # 5*5: patch, the kernel size
    # 1: in_size, the depth (number of channels) of the input
    # 32: out_size, the output depth, i.e. the number of kernels
    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])
    h_conv1 = tf.nn.relu(conv2d(img_array, W_conv1)+b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)

    '''
    3 Second convolutional layer
    Input: None*14*14*32 (the previous layer has 32 kernels, and pooling halves the image to 14*14)
    Kernel size: 5x5, stride 1          Input depth: 32          Number of kernels: 64
    Pooling window: 2x2                 Pooling type: max pooling
    '''
    W_conv2 = weight_variable([5, 5, 32, 64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2)+b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)

    '''
    4 Third layer: fully connected
    Input: None*7*7*64
    Weight matrix size: [7*7*64, 1024] (1024 is a user-chosen width)
    Output: None*1024
    '''
    W_fc1 = weight_variable([7*7*64, 1024])
    b_fc1 = bias_variable([1024])
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1)+b_fc1)
    # Apply dropout to the fully connected layer
    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    '''
    5 Softmax layer
    Input: None*1024
    Weight matrix size: 1024*10
    Output: None*10
    '''
    # Map the 1024-dimensional feature vector to 10 classes, one per digit
    W_fc2= weight_variable([1024, 10])
    b_fc2 = bias_variable([10])
    y_conv = tf.matmul(h_fc1_drop,W_fc2)+b_fc2

    '''
    6 Train and evaluate
    '''
    cross_entry = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
    train_step = tf.train.GradientDescentOptimizer(0.0001).minimize(cross_entry)

    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()

    correction_prediction = tf.equal(tf.argmax(y_, 1), tf.argmax(y_conv, 1))
    accuracy = tf.reduce_mean(tf.cast(correction_prediction, tf.float32))

    for i in range(20000):
        batch_xs, batch_ys = mnist_data.train.next_batch(100)
        sess.run(train_step, feed_dict={x:batch_xs, y_:batch_ys, keep_prob:0.5})
        if i%100 == 0:
            # Every 100 steps, evaluate on the first 100 test images
            train_accuracy = accuracy.eval(
                feed_dict={x:mnist_data.test.images[0:100], y_:mnist_data.test.labels[0:100], keep_prob:1.0})
            print("step: %d, accuracy:%g" %(i, train_accuracy))

    print(sess.run(accuracy, feed_dict={x:mnist_data.test.images[0:100], y_:mnist_data.test.labels[0:100], keep_prob:1.0}))
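
    # Added sketch (not in the original post): the accuracy above is computed on only the
    # first 100 test images. To score the full 10,000-image test set without feeding it all
    # at once, average the per-batch accuracy over equal-sized batches:
    num_batches = len(mnist_data.test.images) // 100
    total_acc = 0.0
    for j in range(num_batches):
        xs = mnist_data.test.images[j*100:(j+1)*100]
        ys = mnist_data.test.labels[j*100:(j+1)*100]
        total_acc += sess.run(accuracy, feed_dict={x: xs, y_: ys, keep_prob: 1.0})
    print("full test accuracy: %g" % (total_acc / num_batches))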



'''
step: 0, accuracy:0.16
step: 100, accuracy:0.22
step: 200, accuracy:0.27
step: 300, accuracy:0.32
step: 400, accuracy:0.34
step: 500, accuracy:0.36
step: 600, accuracy:0.39
step: 700, accuracy:0.42
step: 800, accuracy:0.47
step: 900, accuracy:0.47
step: 1000, accuracy:0.49
step: 1100, accuracy:0.52
step: 1200, accuracy:0.52
step: 1300, accuracy:0.54
step: 1400, accuracy:0.56
step: 1500, accuracy:0.57
step: 1600, accuracy:0.58
step: 1700, accuracy:0.6
step: 1800, accuracy:0.62
step: 1900, accuracy:0.61
step: 2000, accuracy:0.62
step: 2100, accuracy:0.63
step: 2200, accuracy:0.66
step: 2300, accuracy:0.67
step: 2400, accuracy:0.68
step: 2500, accuracy:0.7
step: 2600, accuracy:0.7
step: 2700, accuracy:0.72
step: 2800, accuracy:0.73
step: 2900, accuracy:0.75
step: 3000, accuracy:0.76
step: 3100, accuracy:0.76
step: 3200, accuracy:0.76
step: 3300, accuracy:0.77
step: 3400, accuracy:0.77
step: 3500, accuracy:0.77
step: 3600, accuracy:0.78
step: 3700, accuracy:0.78
step: 3800, accuracy:0.78
step: 3900, accuracy:0.78
step: 4000, accuracy:0.78
step: 4100, accuracy:0.78
step: 4200, accuracy:0.78
step: 4300, accuracy:0.79
step: 4400, accuracy:0.79
step: 4500, accuracy:0.78
step: 4600, accuracy:0.79
step: 4700, accuracy:0.82
step: 4800, accuracy:0.81
step: 4900, accuracy:0.81
step: 5000, accuracy:0.82
step: 5100, accuracy:0.84
step: 5200, accuracy:0.84
step: 5300, accuracy:0.84
step: 5400, accuracy:0.84
step: 5500, accuracy:0.84
step: 5600, accuracy:0.84
step: 5700, accuracy:0.85
step: 5800, accuracy:0.85
step: 5900, accuracy:0.85
step: 6000, accuracy:0.85
step: 6100, accuracy:0.85
step: 6200, accuracy:0.85
step: 6300, accuracy:0.85
step: 6400, accuracy:0.85
step: 6500, accuracy:0.85
step: 6600, accuracy:0.87
step: 6700, accuracy:0.86
step: 6800, accuracy:0.86
step: 6900, accuracy:0.85
step: 7000, accuracy:0.87
step: 7100, accuracy:0.87
step: 7200, accuracy:0.87
step: 7300, accuracy:0.87
step: 7400, accuracy:0.88
step: 7500, accuracy:0.88
step: 7600, accuracy:0.88
step: 7700, accuracy:0.88
step: 7800, accuracy:0.88
step: 7900, accuracy:0.88
step: 8000, accuracy:0.88
step: 8100, accuracy:0.88
step: 8200, accuracy:0.88
step: 8300, accuracy:0.88
step: 8400, accuracy:0.88
step: 8500, accuracy:0.88
step: 8600, accuracy:0.88
step: 8700, accuracy:0.88
step: 8800, accuracy:0.88
step: 8900, accuracy:0.88
step: 9000, accuracy:0.88
step: 9100, accuracy:0.88
step: 9200, accuracy:0.89
step: 9300, accuracy:0.89
step: 9400, accuracy:0.89
step: 9500, accuracy:0.89
step: 9600, accuracy:0.89
step: 9700, accuracy:0.89
step: 9800, accuracy:0.89
step: 9900, accuracy:0.89
step: 10000, accuracy:0.89
step: 10100, accuracy:0.89
step: 10200, accuracy:0.89
step: 10300, accuracy:0.89
step: 10400, accuracy:0.89
step: 10500, accuracy:0.89
step: 10600, accuracy:0.89
step: 10700, accuracy:0.89
step: 10800, accuracy:0.89
step: 10900, accuracy:0.89
step: 11000, accuracy:0.89
step: 11100, accuracy:0.89
step: 11200, accuracy:0.89
step: 11300, accuracy:0.89
step: 11400, accuracy:0.89
step: 11500, accuracy:0.89
step: 11600, accuracy:0.89
step: 11700, accuracy:0.89
...
'''
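
The accuracy in the log climbs slowly because the graph is trained with plain gradient descent at a learning rate of 0.0001. A minimal variant worth trying (a sketch, not from the original post) is to swap the optimizer line for Adam at the same learning rate, which converges much faster on this architecture:

# Replace the GradientDescentOptimizer line above with:
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entry)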



Reposted from blog.csdn.net/Valieli/article/details/103894694