Implementing Dropout in TensorFlow

Copyright notice: this is an original article by the author; please credit the source when reposting: https://blog.csdn.net/PoGeN1/article/details/84633592

Code (keep_prob = 1.0)

# Import the required modules
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Hyperparameters
batch_size = 120  # mini-batch size
n_batch = mnist.train.num_examples // batch_size  # number of batches per epoch
# Placeholders
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
keep_prob = tf.placeholder(tf.float32)  # probability of keeping a unit in dropout
# Initialize weights and biases; three hidden layers, with dropout applied
# to each hidden layer's activations
W1 = tf.Variable(tf.truncated_normal([784, 2000], stddev=0.1))
b1 = tf.Variable(tf.zeros([2000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob)

W2 = tf.Variable(tf.truncated_normal([2000, 2000], stddev=0.1))
b2 = tf.Variable(tf.zeros([2000]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob)

W3 = tf.Variable(tf.truncated_normal([2000, 1000], stddev=0.1))
b3 = tf.Variable(tf.zeros([1000]) + 0.1)
L3 = tf.nn.tanh(tf.matmul(L2_drop, W3) + b3)
L3_drop = tf.nn.dropout(L3, keep_prob)

# Output layer with softmax
W4 = tf.Variable(tf.truncated_normal([1000, 10], stddev=0.1))
b4 = tf.Variable(tf.zeros([10]) + 0.1)
prediction = tf.nn.softmax(tf.matmul(L3_drop, W4) + b4)
# Define the loss (quadratic/MSE loss)
loss = tf.reduce_mean(tf.square(y - prediction))
# Cross-entropy alternative (note: softmax_cross_entropy_with_logits expects the
# raw pre-softmax logits, i.e. tf.matmul(L3_drop, W4) + b4, not the softmaxed prediction):
# loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=tf.matmul(L3_drop, W4) + b4))
# Gradient descent with learning rate 0.2
train = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
init = tf.global_variables_initializer()
init=tf.global_variables_initializer()
# Prediction accuracy
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(31):
        for _ in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 1.0})
        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
        train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
        print('Iter:' + str(epoch) + ',Test_Accuracy:' + str(test_acc) + ',Training_Accuracy:' + str(train_acc))
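
A note on scaling: tf.nn.dropout implements inverted dropout, i.e. during training the surviving activations are multiplied by 1/keep_prob, so the expected magnitude of each layer's output is unchanged and no extra rescaling is needed when evaluating with keep_prob = 1.0. A minimal TF 1.x sketch that makes this visible:

import tensorflow as tf

# Each element is zeroed with probability 1 - keep_prob; the survivors are
# scaled by 1/keep_prob, so the expected value of every entry stays 1.0.
a = tf.ones([2, 5])
a_drop = tf.nn.dropout(a, keep_prob=0.5)

with tf.Session() as sess:
    print(sess.run(a_drop))  # kept entries print as 2.0 (= 1/0.5), dropped ones as 0.0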

Output (keep_prob = 1.0, i.e., dropout effectively disabled)

Iter:0,Test_Accuracy:0.9107,Training_Accuracy:0.921
Iter:1,Test_Accuracy:0.9289,Training_Accuracy:0.9506182
Iter:2,Test_Accuracy:0.9357,Training_Accuracy:0.9636727
Iter:3,Test_Accuracy:0.9426,Training_Accuracy:0.97185457
Iter:4,Test_Accuracy:0.9447,Training_Accuracy:0.9766727
Iter:5,Test_Accuracy:0.947,Training_Accuracy:0.9801818
Iter:6,Test_Accuracy:0.9493,Training_Accuracy:0.9828909
Iter:7,Test_Accuracy:0.9507,Training_Accuracy:0.98489094
Iter:8,Test_Accuracy:0.9534,Training_Accuracy:0.98614544
Iter:9,Test_Accuracy:0.9537,Training_Accuracy:0.9876364
Iter:10,Test_Accuracy:0.9554,Training_Accuracy:0.98865455
Iter:11,Test_Accuracy:0.955,Training_Accuracy:0.98947275
Iter:12,Test_Accuracy:0.9573,Training_Accuracy:0.99045455
Iter:13,Test_Accuracy:0.9575,Training_Accuracy:0.9911091
Iter:14,Test_Accuracy:0.9581,Training_Accuracy:0.99152726
Iter:15,Test_Accuracy:0.9583,Training_Accuracy:0.9921273
Iter:16,Test_Accuracy:0.958,Training_Accuracy:0.99252725
Iter:17,Test_Accuracy:0.9593,Training_Accuracy:0.9930182
Iter:18,Test_Accuracy:0.9582,Training_Accuracy:0.9934
Iter:19,Test_Accuracy:0.9586,Training_Accuracy:0.99363637
Iter:20,Test_Accuracy:0.9587,Training_Accuracy:0.9938727
Iter:21,Test_Accuracy:0.9591,Training_Accuracy:0.9942182
Iter:22,Test_Accuracy:0.9591,Training_Accuracy:0.9944182
Iter:23,Test_Accuracy:0.9596,Training_Accuracy:0.99463636
Iter:24,Test_Accuracy:0.9602,Training_Accuracy:0.9949091
Iter:25,Test_Accuracy:0.9603,Training_Accuracy:0.9950909
Iter:26,Test_Accuracy:0.9605,Training_Accuracy:0.99521816
Iter:27,Test_Accuracy:0.9605,Training_Accuracy:0.9953273
Iter:28,Test_Accuracy:0.9611,Training_Accuracy:0.99554545
Iter:29,Test_Accuracy:0.9611,Training_Accuracy:0.9956545
Iter:30,Test_Accuracy:0.9617,Training_Accuracy:0.9958182

Output (keep_prob = 0.7):

Simply set keep_prob to 0.7 when training, as shown in the sketch below. The results show that dropout reduces overfitting: test accuracy and training accuracy stay close to each other throughout training, unlike the keep_prob = 1.0 run above, where they drift about four percentage points apart, a clear sign of overfitting.
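
A sketch of the modified training loop (assuming, as is standard practice, that only the training feed changes, while evaluation still uses keep_prob = 1.0 so the full network is used at test time):

# Modified loop: dropout active (keep_prob = 0.7) during training,
# disabled (keep_prob = 1.0) when measuring test and training accuracy.
for epoch in range(31):
    for _ in range(n_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        sess.run(train, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.7})
    test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images,
                                             y: mnist.test.labels, keep_prob: 1.0})
    train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images,
                                              y: mnist.train.labels, keep_prob: 1.0})
    print('Iter:' + str(epoch) + ',Test_Accuracy:' + str(test_acc) +
          ',Training_Accuracy:' + str(train_acc))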


Iter:0,Test_Accuracy:0.8779,Training_Accuracy:0.8677273
Iter:1,Test_Accuracy:0.8986,Training_Accuracy:0.894
Iter:2,Test_Accuracy:0.9111,Training_Accuracy:0.9067818
Iter:3,Test_Accuracy:0.9181,Training_Accuracy:0.91587275
Iter:4,Test_Accuracy:0.9215,Training_Accuracy:0.91916364
Iter:5,Test_Accuracy:0.9255,Training_Accuracy:0.92530906
Iter:6,Test_Accuracy:0.9307,Training_Accuracy:0.92807275
Iter:7,Test_Accuracy:0.9341,Training_Accuracy:0.9322364
Iter:8,Test_Accuracy:0.9344,Training_Accuracy:0.93485457
Iter:9,Test_Accuracy:0.9354,Training_Accuracy:0.9373818
Iter:10,Test_Accuracy:0.9397,Training_Accuracy:0.9401091
Iter:11,Test_Accuracy:0.9373,Training_Accuracy:0.9408909
Iter:12,Test_Accuracy:0.9402,Training_Accuracy:0.9438909
Iter:13,Test_Accuracy:0.9397,Training_Accuracy:0.9456364
Iter:14,Test_Accuracy:0.9426,Training_Accuracy:0.9463636
Iter:15,Test_Accuracy:0.9441,Training_Accuracy:0.9484
Iter:16,Test_Accuracy:0.9441,Training_Accuracy:0.9497455
Iter:17,Test_Accuracy:0.9489,Training_Accuracy:0.9516909
Iter:18,Test_Accuracy:0.9491,Training_Accuracy:0.9518909
Iter:19,Test_Accuracy:0.9479,Training_Accuracy:0.9530182
Iter:20,Test_Accuracy:0.9528,Training_Accuracy:0.95443636
Iter:21,Test_Accuracy:0.951,Training_Accuracy:0.95565456
Iter:22,Test_Accuracy:0.9502,Training_Accuracy:0.9562727
Iter:23,Test_Accuracy:0.9518,Training_Accuracy:0.9575091
Iter:24,Test_Accuracy:0.9514,Training_Accuracy:0.9581818
Iter:25,Test_Accuracy:0.9514,Training_Accuracy:0.9590727
Iter:26,Test_Accuracy:0.9544,Training_Accuracy:0.96043634
Iter:27,Test_Accuracy:0.9547,Training_Accuracy:0.96016365
Iter:28,Test_Accuracy:0.9543,Training_Accuracy:0.9614546
Iter:29,Test_Accuracy:0.9545,Training_Accuracy:0.9618545
Iter:30,Test_Accuracy:0.9561,Training_Accuracy:0.9617818
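
For reference, the same architecture can be written much more compactly in TF 2.x with tf.keras (a sketch, not the author's original code; tf.keras.layers.Dropout takes a drop rate, so rate = 1 - keep_prob = 0.3 here):

import tensorflow as tf

# Sketch of the same 784-2000-2000-1000-10 network in tf.keras (TF 2.x).
# Dropout(0.3) drops 30% of units during training only; model.evaluate()
# and model.predict() automatically run with dropout disabled.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(2000, activation='tanh', input_shape=(784,)),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(2000, activation='tanh'),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(1000, activation='tanh'),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.2),
              loss='categorical_crossentropy',
              metrics=['accuracy'])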
