TensorFlow1实现正则化

TensorFlow1中实现正则化的方法有多种,下面依次介绍

1.将loss包含的项手动加入collection

这种方法在创建层时要用最原始的api,如tf.nn.conv2d,因为需要手动创建变量并加入collection

#为全连接层创建变量w,计算其L2损失,并加入名为"losses"的collection中
w = tf.get_variable(name='w', shape=[384, 192], initializer=tf.truncated_normal_initializer(stddev=0.05))
regular = tf.multiply(tf.nn.l2_loss(w), 0.005, name='regular')
tf.add_to_collection('losses', regular)
#将交叉熵也加入"losses"collection中,并全部相加得到l2正则化后的loss
cross_entropy = tf.losses.softmax_cross_entropy(onehot_labels=record_labels, logits=y_predict)
tf.add_to_collection('losses', cross_entropy)
loss = tf.add_n(tf.get_collection('losses'))

2.使用get_variable中的regularizer参数

regularizer参数可设定该变量的正则化形式,并将其加入TensorFlow自带的一个正则化collection中,在variable_scope中也有regularizer参数,操作相同。
在高级一点的网络层api中,如tf.layers.conv2d,可使用kernel_regularizer参数定义变量的正则化形式

#定义一个regularizer
#regularizer = tf.contrib.layers.l2_regularizer(scale=0.1)(弃用)
regularizer=keras.regularizers.l2(0.001)
#创建变量时使用改regularizer
weights = tf.get_variable(
        name="weights",
        regularizer=regularizer,
        ...
    )
#得到保存正则化项的collection,并与原loss相加
reg_variables = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
loss = cross_entropy + tf.add_n(reg_variables)
#collection中的所有数据都是在一个list中,也可以通过另一种函数直接得到保存正则项的collection中的所有值之和
loss = cross_entropy + tf.losses.get_regularization_loss()

3.在Keras的层中使用kernel_regularizer参数

之前的正则化方式都需要手动将原损失和正则项加在一起,而更高级的Keras api中,在kernel_regularizer参数上定义正则化方式后,无需手动与原loss相加

model = keras.models.Sequential([
    keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
                       activation=tf.nn.relu, input_shape=(NUM_WORDS,)),
    keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
                       activation=tf.nn.relu),
    keras.layers.Dense(1, activation=tf.nn.sigmoid)
])

4.将网络中所有参数加上相同的正则系数

l2_loss = weight_decay * tf.add_n(
     [tf.nn.l2_loss(tf.cast(v, tf.float32)) for v in tf.trainable_variables()])
loss = cross_entropy + l2_loss

猜你喜欢

转载自blog.csdn.net/qq_43221336/article/details/106331623
今日推荐