TensorFlow MNIST 手写数字识别之过拟合

1. 过拟合 overfitting 问题

什么是过拟合呢?

用实际生活中的一个例子来比喻一下过拟合现象. 说白了, 就是机器学习模型于自信. 已经到了自负的阶段了. 那自负的坏处, 大家也知道, 就是在自己的小圈子里表现非凡, 不过在现实的大圈子里却往往处处碰壁. 所以在这个简介里, 我们把自负和过拟合画上等号.

学习模型可能太满足了所有的训练数据,所以导致在实际数据中误差陡增,如下图,绿的的线是过拟合的训练结果,黑色的分界线是我门期望的结果。

这里写图片描述

解决过拟合

  • 方法一: 增加数据量, 大部分过拟合产生的原因是因为数据量太少了. 如果我们有成千上万的数据, 红线也会慢慢被拉直, 变得没那么扭曲 .

  • 方法二:运用正规化. L1, l2 regularization等等, 这些方法适用于大多数的机器学习, 包括神经网络. 他们的做法大同小异, 我们简化机器学习的关键公式为 y=Wx . W为机器需要学习到的各种参数. 在过拟合中, W 的值往往变化得特别大或特别小. 为了不让W变化太大, 我们在计算误差上做些手脚. 原始的 cost 误差是这样计算, cost = 预测值-真实值的平方. 如果 W 变得太大, 我们就让 cost 也跟着变大, 变成一种惩罚机制. 所以我们把 W 自己考虑进来. 这里 abs 是绝对值. 这一种形式的 正规化, 叫做 l1 正规化. L2 正规化和 l1 类似, 只是绝对值换成了平方. 其他的l3, l4 也都是换成了立方和4次方等等. 形式类似. 用这些方法,我们就能保证让学出来的线条不会过于扭曲.

  • dropout: 在训练的时候, 我们随机忽略掉一些神经元和神经联结 , 是这个神经网络变得”不完整”. 用一个不完整的神经网络训练一次.

    到第二次再随机忽略另一些, 变成另一个不完整的神经网络. 有了这些随机 drop 掉的规则, 我们可以想象其实每次训练的时候, 我们都让每一次预测结果都不会依赖于其中某部分特定的神经元.

2. 在MNIST中解决过拟合

关于处理数据和构建神经网络 请参照博客 TensorFlow 入门之第一个神经网络和训练 MNIST

代码使用了tensorboard 可视化(关于tensorboard 有时间会专门介绍)来观察 loss 函数的变化,使用 dropout来解决过拟合问题。

下面是主要代码

扫描二维码关注公众号,回复: 1584942 查看本文章
# -*- coding: utf-8 -*-
# 用 drop out 解决 Overfitting 问题
import 手写数字识别.input_data  
import tensorflow as tf
mnist = 手写数字识别.input_data.read_data_sets("手写数字识别/MNIST_data/", one_hot=True)  

# 添加神经层的函数def add_layer(),它有四个参数:输入值、输入的大小、输出的大小和激励函数,我们设定默认的激励函数是None。也就是线性函数
def add_layer(inputs, in_size, out_size,layer_name, activation_function=None):
    # 定义权重,尽量是一个随机变量
    # 因为在生成初始参数时,随机变量(normal distribution) 会比全部为0要好很多,所以我们这里的weights 是一个 in_size行, out_size列的随机变量矩阵。   
    Weights = tf.Variable(tf.random_normal([in_size, out_size]))
    # 在机器学习中,biases的推荐值不为0,所以我们这里是在0向量的基础上又加了0.1。
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    # 定义Wx_plus_b, 即神经网络未激活的值(预测的值)。其中,tf.matmul()是矩阵的乘法。
    Wx_plus_b = tf.matmul(inputs, Weights) + biases
    #  用dropout来解决过拟合 问题
    Wx_plus_b = tf.nn.dropout(Wx_plus_b, keep_prob)
    # activation_function ——激励函数(激励函数是非线性方程)为None时(线性关系),输出就是当前的预测值——Wx_plus_b,
    # 不为None时,就把Wx_plus_b传到activation_function()函数中得到输出。
    if activation_function is None:
        outputs = Wx_plus_b
    else:
        # 返回输出
        outputs = activation_function(Wx_plus_b)
        tf.summary.histogram(layer_name + '/outputs', outputs)
    return outputs

# 保留数据 keep_prob
keep_prob = tf.placeholder(tf.float32)
xs = tf.placeholder(tf.float32,[None, 784]) #图像输入向量  每个图片有784 (28 *28) 个像素点
ys = tf.placeholder(tf.float32, [None,10]) #每个例子有10 个输出

# prediction = add_layer(xs, 784, 10, activation_function=tf.nn.softmax)
l1 = add_layer(xs, 784, 50, 'l1', activation_function=tf.nn.tanh)
prediction = add_layer(l1, 50, 10,'l2' ,activation_function=tf.nn.softmax)
#loss函数(即最优化目标函数)选用交叉熵函数。交叉熵用来衡量预测值和真实值的相似程度,如果完全相同,它们的交叉熵等于零 ,所以loss 越小 学的好
#分类一般都是 softmax+ cross_entropy
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
reduction_indices=[1]))
tf.summary.scalar('loss', cross_entropy)

#train方法(最优化算法)采用梯度下降法。  优化器 如何让机器学习提升它的准确率。 tf.train.GradientDescentOptimizer()中的值(学习的效率)通常都小于1
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
sess = tf.Session()

merged = tf.summary.merge_all() # tensorflow >= 0.12
train_writer = tf.summary.FileWriter("/Users/yangyibo/test/logs/train",sess.graph)
test_writer = tf.summary.FileWriter("/Users/yangyibo/test/logs/test",sess.graph)

# 初始化变量
init= tf.global_variables_initializer()
sess.run(init)

for i in range(1000):
    #开始train,每次只取100张图片,免得数据太多训练太慢
    batch_xs, batch_ys = mnist.train.next_batch(50)
    # 丢弃百分之40 的数据
    sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 0.5})
    if i % 50 == 0:
        train_result = sess.run(merged, feed_dict={xs: batch_xs, ys: batch_ys, keep_prob: 1})
        test_result = sess.run(merged, feed_dict={xs: mnist.test.images, ys: mnist.test.labels, keep_prob: 1})
        train_writer.add_summary(train_result,i)
        test_writer.add_summary(test_result,i)

本文源码:
https://github.com/527515025/My-TensorFlow-tutorials/blob/master/tensorflow_10_dropout.py
欢迎 start

本文参考:莫烦老师的pyhon课程

猜你喜欢

转载自blog.csdn.net/u012373815/article/details/78525542
今日推荐