网络优化方法--Dropout

网络优化方法--Dropout

1、Dropout介绍

  Dropout 也是一种用于抵抗过拟合的技术,它试图改变网络本身来对网络进行优化。我 们先来了解一下它的工作机制,当我们训练一个普通的神经网络时,网络的结构可能如图所示。

image-20220508224755075

  Dropout 通常是在神经网络隐藏层的部分使用,使用的时候会临时关闭掉一部分的神经 元,我们可以通过一个参数来控制神经元被关闭的概率,网络结构如图所示。

image-20220508224829730

更详细的流程如下:

  1. 在模型训练阶段我们可以先给 Dropout 参数设置一个值,例如 0.4。意思是 大约 60%的神经元是工作的,大约 40%神经元是不工作的
  2. 给需要进行Dropout的神经网络层的每一个神经元生成一个0-1 的随机数(一 般是对隐藏层进行 Dropout)。如果神经元的随机数小于 0.6,那么该神经元就设置为 工作状态的;如果神经元的随机数大于等于 0.6,那么该神经元就设置为不工作的,不工作状态的意思就是不参与计算和训练,可以当这个神经元不存在。
  3. 设置好一部分神经元工作一部分神经元不工作之后,我们会发现神经网络的输 出值会发现变化,如上图,如果隐藏层有一半不工作,那么网络输出值就会比原来的值要小,因为计算 WX+b 时,如果 W 矩阵中,有一部分的值变成 0,那么最后 的计算结果肯定会变小。所以为了使用 Dropout 的网络层神经元信号的总和不会发生 太大的变化,对于工作的神经元的输出信号还需要除以 0.4。
  4. 训练阶段重复 1-3 步骤,每一次都随机选择部分的神经元参与训练。
  5. 在测试阶段所有的神经元都参与计算。

   Dropout 为什么会起作用呢?这个问题很难通过数学推导来证明。我们在介绍 ReLU 激 活函数的时候有提到过神经网络的信号是冗余的,神经网络在做预测时并不需要隐藏层所有神 经元都工作,只需要一部分隐藏层神经元工作即可。我们可以抽象地来理解 Dropout,当我们 使用 Dropout 的时候,就有点像我们在训练很多不同的结构更简单的神经网络,最后测试阶 段再综合所有的网络结构得到结果。或者另外一种理解方式是我们使用 Dropout 的时候减少 了神经元之间的相互关联,同时强制网络使用更少的特征来做预测,可以增加模型的健壮性。

  除了这两种理解方式之外还可以有其他的很多理解方式,深度学习中很多技巧都是不能用 数学推导得到同时又比较难理解的。但重要的是这些技巧在实际应用中可以帮助我们得到更好 的结果。

  Dropout 比较适合应用于只有少量数据但是需要训练复杂模型的场景,这类场景在图像 领域比较常见,所以 Dropout 经常用于图像领域。

2、Dropout程序

  这里我们而将看到一个Dropout在MNIST数据集识别中的应用,我们建立两个模型,一个使用Dropout,另一个不使用Dropout,对比两个模型的收敛速度。

代码在Jupyter Notebook中调试。

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Dropout,Flatten
from tensorflow.keras.optimizers import SGD
import matplotlib.pyplot as plt
import numpy as np

# 载入数据集
mnist = tf.keras.datasets.mnist
# 载入训练集和测试集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 对训练集和测试集的数据进行归一化处理,有助于提升模型训练速度
x_train, x_test = x_train / 255.0, x_test / 255.0
# 把训练集和测试集的标签转为独热编码
y_train = tf.keras.utils.to_categorical(y_train,num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test,num_classes=10)

# 模型定义,model1使用Dropout
# Dropout(0.4)表示隐藏层40%神经元不工作
model1 = Sequential([
        Flatten(input_shape=(28, 28)),
        Dense(units=200,activation='tanh'),
        Dropout(0.4),
        Dense(units=100,activation='tanh'),
        Dropout(0.4),
        Dense(units=10,activation='softmax')
        ])

# 在定义一个一模一样的模型用于对比测试,model2不使用Dropout
# Dropout(0)表示隐藏层所有神经元都工作,相当于没有Dropout
model2 = Sequential([
        Flatten(input_shape=(28, 28)),
        Dense(units=200,activation='tanh'),
        Dropout(0),
        Dense(units=100,activation='tanh'),
        Dropout(0),
        Dense(units=10,activation='softmax')
        ])

# sgd定义随机梯度下降法优化器
# loss='categorical_crossentropy'定义交叉熵代价函数
# metrics=['accuracy']模型在训练的过程中同时计算准确率
sgd = SGD(0.2)
model1.compile(optimizer=sgd,
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model2.compile(optimizer=sgd,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# 传入训练集数据和标签训练模型
# 周期大小为30(把所有训练集数据训练一次称为训练一个周期)
epochs = 30
# 批次大小为32(每次训练模型传入32个数据进行训练)
batch_size=32
# validation_data设置验证集数据
# 先训练model1
history1 = model1.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test,y_test))
# 再训练model2
history2 = model2.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, validation_data=(x_test,y_test))

  训练过程:

Train on 60000 samples, validate on 10000 samples
Epoch 1/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.4173 - accuracy: 0.8728 - val_loss: 0.2200 - val_accuracy: 0.9337
Epoch 2/30
60000/60000 [==============================] - 5s 78us/sample - loss: 0.2786 - accuracy: 0.9171 - val_loss: 0.1616 - val_accuracy: 0.9516
Epoch 3/30
60000/60000 [==============================] - 4s 73us/sample - loss: 0.2384 - accuracy: 0.9293 - val_loss: 0.1603 - val_accuracy: 0.9519
Epoch 4/30
60000/60000 [==============================] - 4s 74us/sample - loss: 0.2182 - accuracy: 0.9347 - val_loss: 0.1393 - val_accuracy: 0.9577
Epoch 5/30
60000/60000 [==============================] - 4s 74us/sample - loss: 0.2014 - accuracy: 0.9400 - val_loss: 0.1257 - val_accuracy: 0.9626
Epoch 6/30
60000/60000 [==============================] - 5s 75us/sample - loss: 0.1881 - accuracy: 0.9453 - val_loss: 0.1236 - val_accuracy: 0.9651
Epoch 7/30
60000/60000 [==============================] - 5s 83us/sample - loss: 0.1748 - accuracy: 0.9483 - val_loss: 0.1107 - val_accuracy: 0.9670
Epoch 8/30
60000/60000 [==============================] - 6s 104us/sample - loss: 0.1683 - accuracy: 0.9494 - val_loss: 0.1131 - val_accuracy: 0.9662
Epoch 9/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1597 - accuracy: 0.9517 - val_loss: 0.1066 - val_accuracy: 0.9677
Epoch 10/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1534 - accuracy: 0.9541 - val_loss: 0.0945 - val_accuracy: 0.9709
Epoch 11/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1511 - accuracy: 0.9547 - val_loss: 0.1054 - val_accuracy: 0.9674
Epoch 12/30
60000/60000 [==============================] - 6s 97us/sample - loss: 0.1481 - accuracy: 0.9548 - val_loss: 0.0930 - val_accuracy: 0.9730
Epoch 13/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1406 - accuracy: 0.9586 - val_loss: 0.0937 - val_accuracy: 0.9707
Epoch 14/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1381 - accuracy: 0.9588 - val_loss: 0.0904 - val_accuracy: 0.9735
Epoch 15/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1348 - accuracy: 0.9597 - val_loss: 0.0934 - val_accuracy: 0.9724
Epoch 16/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1304 - accuracy: 0.9614 - val_loss: 0.0865 - val_accuracy: 0.9747
Epoch 17/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1262 - accuracy: 0.9628 - val_loss: 0.0871 - val_accuracy: 0.9745
Epoch 18/30
60000/60000 [==============================] - 6s 96us/sample - loss: 0.1255 - accuracy: 0.9628 - val_loss: 0.0856 - val_accuracy: 0.9735
Epoch 19/30
60000/60000 [==============================] - 6s 100us/sample - loss: 0.1248 - accuracy: 0.9616 - val_loss: 0.0826 - val_accuracy: 0.9747
Epoch 20/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.1180 - accuracy: 0.9651 - val_loss: 0.0847 - val_accuracy: 0.9752
Epoch 21/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.1163 - accuracy: 0.9648 - val_loss: 0.0869 - val_accuracy: 0.9747
Epoch 22/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.1171 - accuracy: 0.9650 - val_loss: 0.0813 - val_accuracy: 0.9764
Epoch 23/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.1160 - accuracy: 0.9647 - val_loss: 0.0872 - val_accuracy: 0.9746
Epoch 24/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1100 - accuracy: 0.9664 - val_loss: 0.0850 - val_accuracy: 0.9759
Epoch 25/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1095 - accuracy: 0.9671 - val_loss: 0.0815 - val_accuracy: 0.9769
Epoch 26/30
60000/60000 [==============================] - 6s 96us/sample - loss: 0.1087 - accuracy: 0.9668 - val_loss: 0.0799 - val_accuracy: 0.9774
Epoch 27/30
60000/60000 [==============================] - 6s 96us/sample - loss: 0.1084 - accuracy: 0.9674 - val_loss: 0.0811 - val_accuracy: 0.9779
Epoch 28/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1055 - accuracy: 0.9683 - val_loss: 0.0794 - val_accuracy: 0.9761
Epoch 29/30
60000/60000 [==============================] - 6s 98us/sample - loss: 0.1030 - accuracy: 0.9689 - val_loss: 0.0803 - val_accuracy: 0.9767
Epoch 30/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.1036 - accuracy: 0.9682 - val_loss: 0.0770 - val_accuracy: 0.9777
Train on 60000 samples, validate on 10000 samples
Epoch 1/30
60000/60000 [==============================] - 6s 99us/sample - loss: 0.2536 - accuracy: 0.9230 - val_loss: 0.1502 - val_accuracy: 0.9537
Epoch 2/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.1172 - accuracy: 0.9641 - val_loss: 0.1013 - val_accuracy: 0.9688
Epoch 3/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.0809 - accuracy: 0.9757 - val_loss: 0.1021 - val_accuracy: 0.9659
Epoch 4/30
60000/60000 [==============================] - 6s 94us/sample - loss: 0.0598 - accuracy: 0.9816 - val_loss: 0.0958 - val_accuracy: 0.9699
Epoch 5/30
60000/60000 [==============================] - 6s 93us/sample - loss: 0.0457 - accuracy: 0.9857 - val_loss: 0.0867 - val_accuracy: 0.9749
Epoch 6/30
60000/60000 [==============================] - 6s 93us/sample - loss: 0.0353 - accuracy: 0.9892 - val_loss: 0.0729 - val_accuracy: 0.9770
Epoch 7/30
60000/60000 [==============================] - 6s 98us/sample - loss: 0.0244 - accuracy: 0.9932 - val_loss: 0.0774 - val_accuracy: 0.9762
Epoch 8/30
60000/60000 [==============================] - 6s 96us/sample - loss: 0.0191 - accuracy: 0.9947 - val_loss: 0.0688 - val_accuracy: 0.9782
Epoch 9/30
60000/60000 [==============================] - 6s 96us/sample - loss: 0.0141 - accuracy: 0.9966 - val_loss: 0.0946 - val_accuracy: 0.9702
Epoch 10/30
60000/60000 [==============================] - 7s 111us/sample - loss: 0.0097 - accuracy: 0.9978 - val_loss: 0.0704 - val_accuracy: 0.9785
Epoch 11/30
60000/60000 [==============================] - 6s 107us/sample - loss: 0.0058 - accuracy: 0.9991 - val_loss: 0.0629 - val_accuracy: 0.9813
Epoch 12/30
60000/60000 [==============================] - 6s 99us/sample - loss: 0.0043 - accuracy: 0.9995 - val_loss: 0.0684 - val_accuracy: 0.9800
Epoch 13/30
60000/60000 [==============================] - 6s 98us/sample - loss: 0.0030 - accuracy: 0.9998 - val_loss: 0.0646 - val_accuracy: 0.9808
Epoch 14/30
60000/60000 [==============================] - 6s 98us/sample - loss: 0.0022 - accuracy: 0.9999 - val_loss: 0.0643 - val_accuracy: 0.9815
Epoch 15/30
60000/60000 [==============================] - 6s 106us/sample - loss: 0.0017 - accuracy: 1.0000 - val_loss: 0.0678 - val_accuracy: 0.9804
Epoch 16/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.0015 - accuracy: 1.0000 - val_loss: 0.0660 - val_accuracy: 0.9811
Epoch 17/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.0013 - accuracy: 1.0000 - val_loss: 0.0667 - val_accuracy: 0.9812
Epoch 18/30
60000/60000 [==============================] - 6s 95us/sample - loss: 0.0011 - accuracy: 1.0000 - val_loss: 0.0670 - val_accuracy: 0.9814
Epoch 19/30
60000/60000 [==============================] - 6s 96us/sample - loss: 0.0010 - accuracy: 1.0000 - val_loss: 0.0668 - val_accuracy: 0.9814
Epoch 20/30
60000/60000 [==============================] - 6s 95us/sample - loss: 9.3235e-04 - accuracy: 1.0000 - val_loss: 0.0676 - val_accuracy: 0.9817
Epoch 21/30
60000/60000 [==============================] - 6s 95us/sample - loss: 8.5067e-04 - accuracy: 1.0000 - val_loss: 0.0673 - val_accuracy: 0.9815
Epoch 22/30
60000/60000 [==============================] - 6s 95us/sample - loss: 7.8290e-04 - accuracy: 1.0000 - val_loss: 0.0688 - val_accuracy: 0.9813
Epoch 23/30
60000/60000 [==============================] - 6s 95us/sample - loss: 7.2826e-04 - accuracy: 1.0000 - val_loss: 0.0682 - val_accuracy: 0.9814
Epoch 24/30
60000/60000 [==============================] - 6s 97us/sample - loss: 6.8046e-04 - accuracy: 1.0000 - val_loss: 0.0691 - val_accuracy: 0.9811
Epoch 25/30
60000/60000 [==============================] - 5s 91us/sample - loss: 6.3994e-04 - accuracy: 1.0000 - val_loss: 0.0696 - val_accuracy: 0.9812
Epoch 26/30
60000/60000 [==============================] - 5s 91us/sample - loss: 5.9906e-04 - accuracy: 1.0000 - val_loss: 0.0699 - val_accuracy: 0.9812
Epoch 27/30
60000/60000 [==============================] - 6s 92us/sample - loss: 5.6810e-04 - accuracy: 1.0000 - val_loss: 0.0696 - val_accuracy: 0.9815
Epoch 28/30
60000/60000 [==============================] - 6s 98us/sample - loss: 5.3810e-04 - accuracy: 1.0000 - val_loss: 0.0707 - val_accuracy: 0.9812
Epoch 29/30
60000/60000 [==============================] - 6s 96us/sample - loss: 5.1041e-04 - accuracy: 1.0000 - val_loss: 0.0707 - val_accuracy: 0.9811
Epoch 30/30
60000/60000 [==============================] - 6s 96us/sample - loss: 4.8516e-04 - accuracy: 1.0000 - val_loss: 0.0712 - val_accuracy: 0.9819

  这里是用两个模型对比的,所以训练过程包含了两个模型的结果。

# 画出model1验证集准确率曲线图
plt.plot(np.arange(epochs),history1.history['val_accuracy'],c='b',label='Dropout')
# 画出model2验证集准确率曲线图
plt.plot(np.arange(epochs),history2.history['val_accuracy'],c='y',label='FC')
# 图例
plt.legend()
# x坐标描述
plt.xlabel('epochs')
# y坐标描述
plt.ylabel('accuracy')
# 显示图像
plt.show()

image-20220508225451984

  模型训练结果前 1-30 周期是使用了 Dropout 的结果,后面的 1-30 周期是没有使用 Dropout 的结果。观察结果我们发现使用了 Dropout 之后训练集准确率和验证集的准确率相差并不是很大,所以能看出 Dropout 确实是可以起到抵抗过拟合的作用。我们还可以发现一个有趣的现象就是前 1-30 周期 model1 的验证集准确率还高于训练集的准确率,这是因为模 型在计算训练集准确率的时候模型还在使用 Dropout,在计算验证集准确率的时候已经不使 用 Dropout 了。使用 Dropout 的时候模型的准确率会稍微降低一些。同时我们也可以发现, 不用 Dropout 的 model2 中测试集的准确率看起来比使用 Dropout 的 model1 要更高。

  事实上使用 Dropout 之后模型的收敛速度会变慢一些,所以需要更多的训练次数才能得到最好的结果。

  这里不用 Dropout 的 model2 验证集训练 30 个周期最高准确率大概 是 98.2%左右;使用 Dropout 的 model1 如果训练足够多的周期,验证集最高准确率可以达 到 98.8%左右。

猜你喜欢

转载自blog.csdn.net/qq_43753724/article/details/124656662