【Tensorflow】使用get_variable和safe_embedding_lookup_sparse进行embedding并添加L2正则时的注意事项

使用get_variable和safe_embedding_lookup_sparse进行embedding并添加L2正则时的注意事项

safe_embedding_lookup_sparse能够实现“Lookup embedding results, accounting for invalid IDs and empty features”,配合get_variable进行初始化embedding权值,
使用方法为:

from tensorflow.contrib.layers import safe_embedding_lookup_sparse
weight = tf.get_variable(name="linear_weight".format(name),shape=[vocabulary_size, 1],
initializer=tf.glorot_normal_initializer())

当使用这个方法进行LookUp求embedding之后,如果想要添加变量的L2正则损失,有些同学可能会这样做:

l2_loss = tf.contrib.layers.l2_regularizer(0.001)(tf.trainable_variables())

其实这样的做法时错误的。因为,这样会把所有trainable_variables的变量都添加到了l2正则损失里面,但是实际上我们需要的是“用到的那些trainable_variables”。拿embedding来说,embedding我们初始化了一个很大的权值矩阵,但是实际上我们只使用了我们LookUpTable查到的那部分权值,因此只需要把这部分权值添加进去就行了。
至于为什么上面的做法会把全部的权值都添加进l2,我们看一下tf.get_variabletrainable参数描述:

'''
trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). `trainable`
        defaults to `True` unless `synchronization` is set to `ON_READ`.
'''

那该怎么解决呢?

embedding = safe_embedding_lookup_sparse(weight, ids, combiner="mean")
tf.add_to_collection(tf.GraphKeys.WEIGHTS,embedding)

首先使用safe_embedding_lookup_sparse得到embedding向量,然后,将该向量添加进tf.GraphKeys.WEIGHTS这个collection里,然后:

l2_regular = tf.contrib.layers.l2_regularizer(0.001)
regularization_loss = tf.contrib.layers.apply_regularization(l2_regular)

即,定义一个l2_regularizer,然后调用tf.contrib.layers.apply_regularization就可以得到实际使用embedding权值的l2正则损失了。注意tf.contrib.layers.apply_regularization(regularizer, weights_list=None)函数其实有两个参数,第一个是正则化方法,第二个是想要执行正则化方法的变量列表,如果为None,则默认取tf.GraphKeys.WEIGHTS中的weight,这就是我们将需要正则化的变量加入该集合的原因,也可以加入别的集合,只要在函数中指明要正则化的变量集合名字即可。

发布了97 篇原创文章 · 获赞 55 · 访问量 13万+

猜你喜欢

转载自blog.csdn.net/voidfaceless/article/details/103137947
今日推荐