利用CNN卷积神经网络进行手写MNIST数字识别

这几天学了卷积

不过它跟搞 ACM 时学的卷积好像不是一回事,

反正我没感觉到两者有什么关联。

训练 1000 步后,在测试集前 1000 张图片上的准确率约为 97.2%,还不错,可以接受。

上代码吧:

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST dataset object from the 'MNIST_data' directory, requesting
# one-hot encoded labels.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

def compute_accuracy(v_xs, v_ys):
    """Return the fraction of samples in ``v_xs`` that are classified correctly.

    Runs the module-level ``prediction`` tensor in the module-level ``sess``
    with ``keep_prob`` fed as 1, then compares the arg-max of every predicted
    row with the arg-max of the matching row of ``v_ys``.

    The comparison is done in NumPy on the fetched array.  Creating
    ``tf.argmax`` / ``tf.equal`` / ``tf.reduce_mean`` ops here would add new
    nodes to the default graph on every call, so the graph would keep growing
    for as long as the caller keeps evaluating.

    Args:
        v_xs: input batch fed to the ``xs`` placeholder.
        v_ys: label rows, one per sample in ``v_xs`` (presumably one-hot;
            only the arg-max of each row is used).

    Returns:
        The accuracy as a ``numpy.float32`` scalar.
    """
    # Local import: the file does not import numpy at module level.
    import numpy as np

    y_pre = sess.run(prediction, feed_dict={xs: v_xs, keep_prob: 1})
    hits = np.argmax(y_pre, axis=1) == np.argmax(v_ys, axis=1)
    # Cast before averaging so the result stays a float32 scalar.
    return hits.astype(np.float32).mean()

def W(shape):
    """Create a TensorFlow variable of the given shape.

    The initial values are drawn from a truncated normal distribution with a
    standard deviation of 0.1.
    """
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

def Conv2d(x, y):
    """Convolve input ``x`` with filter ``y``.

    Uses a stride of 1 in every dimension and 'SAME' padding.
    """
    unit_strides = [1, 1, 1, 1]
    return tf.nn.conv2d(x, y, strides=unit_strides, padding='SAME')

def Pool(x):
    """Apply max-pooling to ``x``.

    The window and the stride are both [1, 2, 2, 1], with 'SAME' padding.
    """
    window = [1, 2, 2, 1]
    step = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=step, padding='SAME')

# Graph inputs: rows of 784 values, rows of 10 label values, and the dropout
# keep probability.
xs = tf.placeholder(tf.float32, shape=[None, 784])
ys = tf.placeholder(tf.float32, shape=[None, 10])
keep_prob = tf.placeholder(tf.float32)
# View every 784-value row as a 28x28 image with a single channel.
x_imag = tf.reshape(xs, [-1, 28, 28, 1])

# conv1: 5x5 filters, 1 input channel -> 32 output channels, ReLU, then pooling.
W1 = W([5, 5, 1, 32])
B1 = tf.Variable(tf.constant(0.1, shape=[32]))
conv1_pre = Conv2d(x_imag, W1) + B1
H1 = tf.nn.relu(conv1_pre)
P1 = Pool(H1)
# conv2: 5x5 filters, 32 input channels -> 64 output channels, ReLU, then pooling.
W2 = W([5, 5, 32, 64])
B2 = tf.Variable(tf.constant(0.1, shape=[64]))
conv2_pre = Conv2d(P1, W2) + B2
H2 = tf.nn.relu(conv2_pre)
P2 = Pool(H2)
# F1: flatten the pooled output to rows of 7*7*64 values, apply a 1024-unit
# ReLU layer, then dropout controlled by keep_prob.
flat_size = 7 * 7 * 64
W3 = W([flat_size, 1024])
B3 = tf.Variable(tf.constant(0.1, shape=[1024]))
H31 = tf.reshape(P2, [-1, flat_size])
f1_pre = tf.matmul(H31, W3) + B3
H32 = tf.nn.relu(f1_pre)
H33 = tf.nn.dropout(H32, keep_prob)
# F2: map the 1024 features to 10 outputs (H4), then apply softmax.
W4 = W([1024, 10])
b4_init = tf.constant(0.1, shape=[10])
B4 = tf.Variable(b4_init)
H4 = tf.matmul(H33, W4) + B4
prediction = tf.nn.softmax(H4)

# Mean softmax cross-entropy over the batch, computed from the raw logits H4.
# Taking tf.log of the softmax output directly produces inf/NaN as soon as a
# predicted probability underflows to 0; the fused op is numerically stable.
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=ys, logits=H4))
# Minimise the loss with Adam at a learning rate of 1e-4.
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

# Run training inside a context manager so the session is closed when the
# script finishes or an exception interrupts training.  ``sess`` remains a
# module-level name, which compute_accuracy relies on.
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)

    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        # Training steps feed keep_prob 0.5 to the dropout layer.
        sess.run(train_step,
                 feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5})
        if i % 50 == 0:
            # Every 50 steps, print accuracy on the first 1000 test samples.
            print(compute_accuracy(mnist.test.images[:1000],
                                   mnist.test.labels[:1000]))









猜你喜欢

转载自blog.csdn.net/Gipsy_Danger/article/details/81516936