Implementing LeNet with TensorFlow

First come the imports; the data used here is the MNIST dataset:

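# TensorFlow 1.x API (placeholders, sessions); the tutorials.mnist helper was removed in TensorFlow 2.x
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import time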

Declare placeholders for the input images and their class labels:

x = tf.placeholder(tf.float32, [None, 784])   # each row: one flattened 28x28 image
y_ = tf.placeholder(tf.float32, [None, 10])   # each row: one-hot label for digits 0-9
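The data is loaded later with `one_hot=True`, so each row fed into `y_` is a 10-element one-hot vector rather than a digit. Here is a minimal pure-Python sketch of that encoding. The helper name `one_hot` is hypothetical and is not part of the TensorFlow API:

```python
def one_hot(label, num_classes=10):
    """Return a list with 1.0 at position `label` and 0.0 elsewhere."""
    if not 0 <= label < num_classes:
        raise ValueError("label out of range")
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec

# The digit 3 becomes the row that would be fed into y_ for that image.
print(one_hot(3))  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
```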

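Reshape each flat 784-element MNIST vector into a 28x28 image matrix with one channel:

x_image = tf.reshape(x, [-1, 28, 28, 1])  # [batch, height, width, channels]; -1 infers the batch size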
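`tf.reshape` keeps the 784 values in their original row-major order and only reinterprets them as 28 rows of 28 pixels. The trailing `1` in the target shape is the single grayscale channel.

The sketch below does the same reinterpretation for one image in pure Python. The function name `to_image` is illustrative only:

```python
def to_image(flat, height=28, width=28):
    """Reinterpret a flat row-major pixel list as a height x width matrix."""
    if len(flat) != height * width:
        raise ValueError("expected %d values, got %d" % (height * width, len(flat)))
    # Row r holds the pixels flat[r*width : (r+1)*width].
    return [flat[r * width:(r + 1) * width] for r in range(height)]

flat = list(range(784))   # stand-in for one MNIST image vector
img = to_image(flat)
print(len(img), len(img[0]))  # 28 28
print(img[1][0])              # 28: pixel k lands at row k // 28, column k % 28
```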

First convolutional layer (six 5x5 filters, sigmoid activation):

# Six 5x5 filters over one input channel: 28x28x1 -> 28x28x6 with 'SAME' padding
filter1 = tf.Variable(tf.truncated_normal([5, 5, 1, 6]))
bias1 = tf.Variable(tf.truncated_normal([6]))
conv1 = tf.nn.conv2d(x_image, filter1, strides=[1, 1, 1, 1], padding='SAME')
h_conv1 = tf.nn.sigmoid(conv1 + bias1)
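`tf.nn.conv2d` with `strides=[1,1,1,1]` and `padding='SAME'` slides each 5x5 filter over the image as a cross-correlation, so the kernel is not flipped. It zero-pads the border so the output keeps the 28x28 size. The final `sigmoid` step corresponds to `tf.nn.sigmoid(conv1 + bias1)`.

The sketch below handles one input channel and one filter in pure Python. It only illustrates the arithmetic and is far slower than TensorFlow's kernels. The names `conv2d_same` and `sigmoid` are illustrative:

```python
import math

def conv2d_same(image, kernel):
    """Single-channel, stride-1 cross-correlation with 'SAME' zero padding."""
    h, w = len(image), len(image[0])
    kh, kw = len(kernel), len(kernel[0])
    # For stride 1, 'SAME' pads (k - 1) in total; any odd extra row/column
    # goes on the bottom/right, so the top/left padding is (k - 1) // 2.
    top, left = (kh - 1) // 2, (kw - 1) // 2
    out = [[0.0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            acc = 0.0
            for di in range(kh):
                for dj in range(kw):
                    r, c = i + di - top, j + dj - left
                    if 0 <= r < h and 0 <= c < w:  # outside the image = zero padding
                        acc += image[r][c] * kernel[di][dj]
            out[i][j] = acc
    return out

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

img = [[1.0] * 28 for _ in range(28)]   # a constant test image
k = [[1.0] * 5 for _ in range(5)]       # a 5x5 filter of ones
fmap = conv2d_same(img, k)
print(len(fmap), len(fmap[0]))   # 28 28
print(fmap[14][14], fmap[0][0])  # 25.0 9.0 (the corner sees only 3x3 real pixels)

bias = -9.0
activated = [[sigmoid(v + bias) for v in row] for row in fmap]
print(activated[0][0])           # 0.5
```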

First pooling layer (2x2 max pooling, stride 2):

# 28x28x6 -> 14x14x6
maxPool2 = tf.nn.max_pool(h_conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
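`tf.nn.max_pool` with `ksize=[1,2,2,1]` and `strides=[1,2,2,1]` keeps the largest value in each non-overlapping 2x2 window. This halves the height and width.

The pure-Python sketch below works on one feature map and assumes exactly this 2x2, stride-2 setting. Windows that run past the edge are clipped. For this window and stride, clipping gives the same result as TensorFlow's 'SAME' padding on odd sizes (7 becomes 4). The name `max_pool_2x2` is illustrative:

```python
def max_pool_2x2(fmap):
    """2x2 max pooling with stride 2 on one feature map ('SAME'-style edges)."""
    h, w = len(fmap), len(fmap[0])
    out_h, out_w = (h + 1) // 2, (w + 1) // 2  # ceil(n / 2)
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # Window rows 2i..2i+1 and columns 2j..2j+1, clipped at the border.
            window = [fmap[r][c]
                      for r in range(2 * i, min(2 * i + 2, h))
                      for c in range(2 * j, min(2 * j + 2, w))]
            row.append(max(window))
        out.append(row)
    return out

fmap = [[1, 2, 5, 6],
        [3, 4, 7, 8],
        [9, 1, 2, 0],
        [1, 1, 3, 4]]
print(max_pool_2x2(fmap))  # [[4, 8], [9, 4]]
```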

Layer 3, the second convolutional layer:

# Sixteen 5x5 filters over 6 channels: 14x14x6 -> 14x14x16
filter2 = tf.Variable(tf.truncated_normal([5, 5, 6, 16]))
bias2 = tf.Variable(tf.truncated_normal([16]))
conv2 = tf.nn.conv2d(maxPool2, filter2, strides=[1, 1, 1, 1], padding='SAME')
h_conv2 = tf.nn.sigmoid(conv2 + bias2)

# Second pooling layer: 14x14x16 -> 7x7x16
maxPool3 = tf.nn.max_pool(h_conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

# Third convolutional layer, 120 filters: 7x7x16 -> 7x7x120
filter3 = tf.Variable(tf.truncated_normal([5, 5, 16, 120]))
bias3 = tf.Variable(tf.truncated_normal([120]))
conv3 = tf.nn.conv2d(maxPool3, filter3, strides=[1, 1, 1, 1], padding='SAME')
h_conv3 = tf.nn.sigmoid(conv3 + bias3)
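With 'SAME' padding, TensorFlow's output size along each spatial dimension is ceil(input / stride). The stride-1 convolutions therefore keep the size, and each stride-2 pooling halves it, rounding up.

The small pure-Python trace below follows the shapes through the layers above, per image and without the batch dimension. It shows where the `7*7*120` used by the fully connected layer comes from. The helper `same_out` is illustrative:

```python
import math

def same_out(size, stride):
    """Spatial output size for padding='SAME': ceil(size / stride)."""
    return math.ceil(size / stride)

# (layer name, stride, output channels) in the order used above.
layers = [
    ("conv1", 1, 6),
    ("maxPool2", 2, 6),
    ("conv2", 1, 16),
    ("maxPool3", 2, 16),
    ("conv3", 1, 120),
]

size, shapes = 28, []
for name, stride, channels in layers:
    size = same_out(size, stride)
    shapes.append((name, size, size, channels))
    print(name, (size, size, channels))

flat = shapes[-1][1] * shapes[-1][2] * shapes[-1][3]
print("flattened length:", flat)  # 5880 == 7*7*120
```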

Fully connected layers:

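# Hidden fully connected layer: 7*7*120 = 5880 inputs -> 80 units
W_fc1 = tf.Variable(tf.truncated_normal([7*7*120, 80]))
b_fc1 = tf.Variable(tf.truncated_normal([80]))

# Flatten the 7x7x120 output of conv3 into one vector per image
h_pool2_flat = tf.reshape(h_conv3, [-1, 7*7*120])
h_fc1 = tf.nn.sigmoid(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# Output layer weights: 80 units -> 10 classes
W_fc2 = tf.Variable(tf.truncated_normal([80, 10]))
b_fc2 = tf.Variable(tf.truncated_normal([10]))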
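Counting the trainable values in each `tf.Variable` above, weight shape plus bias length, shows that the first fully connected layer dominates the model size. Below is a quick pure-Python tally with the shapes copied from the code. It assumes Python 3.8+ for `math.prod`:

```python
import math

# (weight or filter shape, bias length) for each layer defined above.
params = {
    "conv1": ([5, 5, 1, 6], 6),
    "conv2": ([5, 5, 6, 16], 16),
    "conv3": ([5, 5, 16, 120], 120),
    "fc1": ([7 * 7 * 120, 80], 80),
    "fc2": ([80, 10], 10),
}

counts = {name: math.prod(shape) + bias for name, (shape, bias) in params.items()}
for name, n in counts.items():
    print(name, n)
print("total", sum(counts.values()))  # 521982
```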

The final output layer uses softmax to compute the class probabilities:

y_conv = tf.nn.softmax(tf.matmul(h_fc1,W_fc2) + b_fc2)
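`tf.nn.softmax` turns the 10 output scores (logits) into probabilities that are positive and sum to 1: softmax(z)_i = exp(z_i) / sum_j exp(z_j).

The pure-Python sketch below handles one example. It first subtracts the largest score, a standard trick that avoids overflow in `exp` without changing the result:

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of scores."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([2.0, 1.0, 0.1])
print([round(p, 3) for p in probs])  # [0.659, 0.242, 0.099]
print(softmax([1000.0, 1000.0]))     # [0.5, 0.5], no overflow
```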

Loss function and optimization algorithm:

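# Cross-entropy summed over the batch. Clipping keeps tf.log away from log(0) = -inf,
# which would otherwise make the loss NaN once a predicted probability underflows to 0.
cross_entropy = -tf.reduce_sum(y_ * tf.log(tf.clip_by_value(y_conv, 1e-10, 1.0)))
train_step = tf.train.GradientDescentOptimizer(0.001).minimize(cross_entropy)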
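The loss is the cross-entropy summed over the batch, -sum(y_ * log(y_conv)). The one-hot label zeroes every term except the true class, so each example contributes -log(p_true). In floating point, a probability that underflows to exactly 0 makes log return -inf, and 0 * -inf is NaN.

The pure-Python sketch below computes the same sum. It clips probabilities to `[eps, 1.0]` first, similar in spirit to `tf.clip_by_value`. The value `eps=1e-10` and the name `batch_cross_entropy` are illustrative choices:

```python
import math

def batch_cross_entropy(labels, probs, eps=1e-10):
    """Sum of -y * log(p) over a batch of one-hot labels and predictions.

    Probabilities are clipped to [eps, 1.0] so a prediction of exactly 0
    cannot reach log(0).
    """
    total = 0.0
    for y_row, p_row in zip(labels, probs):
        for y, p in zip(y_row, p_row):
            total -= y * math.log(min(max(p, eps), 1.0))
    return total

labels = [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
probs = [[0.2, 0.7, 0.1], [0.5, 0.5, 0.0]]
loss = batch_cross_entropy(labels, probs)
print(round(loss, 4))  # -log(0.7) - log(0.5), about 1.0498
```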

Create a session and train the network:

sess = tf.Session()
# A prediction is correct when the most probable class matches the labeled class
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
init = tf.global_variables_initializer()
sess.run(init)
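`correct_prediction` compares the index of the largest predicted probability with the index of the 1 in the one-hot label. `accuracy` then averages those booleans as 0/1 floats.

Here is the same computation in pure Python on made-up predictions. The helper names `argmax` and `accuracy` are illustrative, and this `argmax` breaks ties by taking the first index:

```python
def argmax(values):
    """Index of the largest value; the first occurrence wins on ties."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(labels, probs):
    """Fraction of rows whose predicted class matches the one-hot label."""
    correct = [argmax(p) == argmax(y) for y, p in zip(labels, probs)]
    return sum(1.0 if c else 0.0 for c in correct) / len(correct)

labels = [[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0]]
probs = [[0.1, 0.8, 0.1], [0.3, 0.6, 0.1], [0.2, 0.2, 0.6], [0.4, 0.5, 0.1]]
print(accuracy(labels, probs))  # 0.75 (the second row is misclassified)
```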

# Raw string so the Windows backslashes are not treated as escape sequences
mnist_data_set = input_data.read_data_sets(r'E:\Python Project\mnist\MNIST_data', one_hot=True)

start_time = time.time()

for i in range(200000):
    batch_xs, batch_ys = mnist_data_set.train.next_batch(200)
    # Every 5 steps, report accuracy on the current batch and the elapsed time
    if i % 5 == 0:
        train_accuracy = sess.run(accuracy, feed_dict={x: batch_xs, y_: batch_ys})
        print("step %d, training accuracy %g" % (i, train_accuracy))
        end_time = time.time()
        print('time:', (end_time - start_time))
        start_time = end_time
    # Train on every batch; this call belongs outside the if block
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
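For a sense of scale: by default, `read_data_sets` holds out 5,000 of MNIST's 60,000 training images for validation (`validation_size=5000`), leaving 55,000 in `mnist_data_set.train`. The loop draws 200,000 batches of 200 images, which is roughly 727 passes over that training set, and it prints a report every 5 steps. The arithmetic:

```python
train_images = 60000 - 5000   # MNIST training split minus the default validation_size
steps, batch_size = 200000, 200
log_every = 5

images_seen = steps * batch_size
epochs = images_seen / train_images
steps_per_epoch = train_images / batch_size
print(images_seen)          # 40000000
print(round(epochs, 1))     # 727.3
print(steps_per_epoch)      # 275.0
print(steps // log_every)   # 40000 accuracy/time reports
```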

Reposted from blog.csdn.net/moge19/article/details/80466002