NCE(ノイズ対比推定)サンプリング法は、トレーニング速度を向上させ、ソースコード内の正のラベルの数が等しくなければならないという問題を解決します。

1。目的

分類タスクでは、分類カテゴリが数十万または数百万のオーダーである場合、最後の分類層の計算には非常に時間がかかります。モデルの計算の複雑さを軽減するために、各順方向計算では、ラベルパラメータのサンプリング部分wwwは計算に参加し、勾配が逆方向に計算されると、計算に関係する部分wwの一部のみが更新されます。w、すべての重みパラメータを更新せずにww毎回w、これにより、モデルのトレーニング速度を大幅に向上させることができます。

2テンソルフローソースコードの解釈

2.1関数の入力と出力

nn_impl.pyこれはNCE損失のテンソルフローソースコードです。次に、ソースコードを整理して説明します。最初に、次のコードに示すように、テンソルフローソースコードでのnce_lossの実装を見てみましょう。

def nce_loss(weights,
             biases,
             labels,
             inputs,
             num_sampled,
             num_classes,
             num_true=1,
             sampled_values=None,
             remove_accidental_hits=False,
             partition_strategy="mod",
             name="nce_loss"):
     #计算采样的labels和对应的logits(wx+b)值
	logits, labels = _compute_sampled_logits(
      weights=weights,
      biases=biases,
      labels=labels,
      inputs=inputs,
      num_sampled=num_sampled,
      num_classes=num_classes,
      num_true=num_true,
      sampled_values=sampled_values,
      subtract_log_q=True,
      remove_accidental_hits=remove_accidental_hits,
      partition_strategy=partition_strategy,
      name=name)
     # 交叉熵loss
	sampled_losses = sigmoid_cross_entropy_with_logits(
      labels=labels, logits=logits, name="sampled_losses")
     # 返回loss求和
	return _sum_rows(sampled_losses)

関数入力パラメーター

  • 重み:[num_classes、dim]、最後の分類層の重みパラメーターwww
  • バイアス:[num_classes]、最後の分類レイヤーのバイアスbbb
  • ラベル: [batch_size、num_true]、int64タイプ、各バッチの正のラベルインデックスidex 各サンプルの正のラベルの数は同じである必要があり、値はnum_trueです(これも実際のアプリケーションでは柔軟性のない部分です。後で改善される予定です計画)
  • 入力:[batch_size、dim]、入力分類層の特徴ベクトル
  • num_sampled:intタイプ、サンプルごとにランダムにサンプリングされた負のサンプルの数
  • num_classes:int型、分類されたラベルの総数
  • num_true:intタイプ、各サンプルのポジティブラベルの数(バッチ内のすべてのサンプルのポジティブラベルは同じである必要があります)
  • sampled_values:トリプル(サンプリングされた候補セット、ポジティブラベルの数、サンプリングされたラベルの数)であるカスタムサンプリングの候補セット。デフォルトはNoneで、log_uniform_candidate_samplerアダプターを使用します。
  • remove_accidental_hits:ブール型。正のラベルセットでサンプリングされたラベルを削除するかどうか。NCEの代わりに負のサンプリング損失を使用するには、「True」に設定します。
  • partition_strategy:2つのモード「mod」と「div」。デフォルトは「mod」です。詳細については、tf.nn.embedding_lookupを参照してください。
  • name:操作の名前

関数の戻り値
各サンプルの損失値に対応する、長さが[batch_size]の1次元ベクトル。

2.2コードの説明

次に、各関数の実装の詳細な分析を行います。上記から、nce_loss関数の下に_compute_sampled_logits、sigmoid_cross_entropy_with_logits、および_sum_rowsの3つの主要な関数があることがわかります。

_compute_sampled_logits関数


def _compute_sampled_logits(weights,
                            biases,
                            labels,
                            inputs,
                            num_sampled,
                            num_classes,
                            num_true=1,
                            sampled_values=None,
                            subtract_log_q=True,
                            div_flag=True,
                            remove_accidental_hits=False,
                            partition_strategy="mod",
                            name=None,
                            seed=None):

   #数据格式转
  if isinstance(weights, variables.PartitionedVariable):
    weights = list(weights)
  if not isinstance(weights, list):
    weights = [weights]
  # 数据格式转换,将label [batch_size, num_ture] 展开,得到一维的size
  with ops.name_scope(name, "compute_sampled_logits",
                      weights + [biases, inputs, labels]):
    if labels.dtype != dtypes.int64:
      labels = math_ops.cast(labels, dtypes.int64)
    labels_flat = array_ops.reshape(labels, [-1])
   #如果采样label不传入,则默认用log_unifrom_candidate_sampler采样器,生成采用的label
    if sampled_values is None:
        sampled_values = candidate_sampling_ops.log_uniform_candidate_sampler(
          true_classes=labels,
          num_true=num_true,
          num_sampled=num_sampled,
          unique=True,
          range_max=num_classes,
          seed=seed)
    # sampled:[num_sampled],true_expected_count:[batch_size,1]
    # sampled_expected_count: [num_sampled]
    # 采样的值不参与梯度更新,所以用stop_gradient标明
    sampled, true_expected_count, sampled_expected_count = (
        array_ops.stop_gradient(s) for s in sampled_values)
    sampled = math_ops.cast(sampled, dtypes.int64)
    # labels_flat:[batch_size * num_true],sampled: [num_sampled]
    #将正label和负label对应的索引合并到一起
    all_ids = array_ops.concat([labels_flat, sampled], 0)
    #通过索引all_ids从权重参数矩阵weights:[num_classes, dim],取出对应的权重参数,得到all_w
    all_w = embedding_ops.embedding_lookup(
        weights, all_ids,partition_strategy=partition_strategy)
    if all_w.dtype != inputs.dtype:
      all_w = math_ops.cast(all_w, inputs.dtype)
    # 抽离出正label w权重参数 true_w,和负label权重参数sampled_w
    #true_w :[batch_size * num_true, dim]
    # sampled_w: [num_sampled, dim], 一个batch里,每个样本的负label都是一样的
    true_w = array_ops.slice(all_w, [0, 0],
                             array_ops.stack(
                                 [array_ops.shape(labels_flat)[0], -1]))
    sampled_w = array_ops.slice(
        all_w, array_ops.stack([array_ops.shape(labels_flat)[0], 0]), [-1, -1])
    #在对应的负label上,计算wx+b,inputs: [batch_size, dim]
    # sampled_logits: [batch_size, num_sampled]
    sampled_logits = math_ops.matmul(inputs, sampled_w, transpose_b=True)
    # 与计算all_w一样,抽取偏移all_b
    all_b = embedding_ops.embedding_lookup(
        biases, all_ids, partition_strategy=partition_strategy)
    if all_b.dtype != inputs.dtype:
      all_b = math_ops.cast(all_b, inputs.dtype)
    # 抽离出正,负偏移b
    true_b = array_ops.slice(all_b, [0], array_ops.shape(labels_flat))
    sampled_b = array_ops.slice(all_b, array_ops.shape(labels_flat), [-1])
    # inputs: [batch_size, dim]
    # true_w: [batch_size * num_true, dim]
    # 计算wx+b,得到true_logits:[ batch_size, num_true]
    dim = array_ops.shape(true_w)[1:2]
    new_true_w_shape = array_ops.concat([[-1, num_true], dim], 0)
    # 做点乘,得到row_wise_dots: [batch_size, num_true, dim]
    row_wise_dots = math_ops.multiply(
        array_ops.expand_dims(inputs, 1),
        array_ops.reshape(true_w, new_true_w_shape))
    #reshape
    dots_as_matrix = array_ops.reshape(row_wise_dots,
                                       array_ops.concat([[-1], dim], 0))
    # 得到正label对应的logits值 [batch_size, num_true]                                                               
    true_logits = array_ops.reshape(_sum_rows(dots_as_matrix), [-1, num_true])
    # +b
    true_b = array_ops.reshape(true_b, [-1, num_true])
    true_logits += true_b
    sampled_logits += sampled_b
################## 此段代码去掉采样的label在正label里  ###########
    if remove_accidental_hits:
      acc_hits = candidate_sampling_ops.compute_accidental_hits(
          labels, sampled, num_true=num_true)
      acc_indices, acc_ids, acc_weights = acc_hits

      # This is how SparseToDense expects the indices.
      acc_indices_2d = array_ops.reshape(acc_indices, [-1, 1])
      acc_ids_2d_int32 = array_ops.reshape(
          math_ops.cast(acc_ids, dtypes.int32), [-1, 1])
      sparse_indices = array_ops.concat([acc_indices_2d, acc_ids_2d_int32], 1,"sparse_indices")
      # Create sampled_logits_shape = [batch_size, num_sampled]
      sampled_logits_shape = array_ops.concat(
          [array_ops.shape(labels)[:1],
           array_ops.expand_dims(num_sampled, 0)], 0)
      if sampled_logits.dtype != acc_weights.dtype:
        acc_weights = math_ops.cast(acc_weights, sampled_logits.dtype)
      sampled_logits += gen_sparse_ops.sparse_to_dense(
          sparse_indices,
          sampled_logits_shape,
          acc_weights,
          default_value=0.0,
          validate_indices=False)
############# 此段代码表示是logits否减去-log(true_expected_count) #######
    if subtract_log_q:
      # Subtract log of Q(l), prior probability that l appears in sampled.
      true_logits -= math_ops.log(true_expected_count)
      sampled_logits -= math_ops.log(sampled_expected_count)
	#将正负logits concat到一起,得到out_logits: [batch_size, num_true+num_sampled]
	out_logits = array_ops.concat([true_logits, sampled_logits], 1)
	# 标签labels,正label为1/num_true,保证总和为1,负label标签为0
	# out_labels: [batch_size, num_ture+num_sampled]
    out_labels = array_ops.concat([
        array_ops.ones_like(true_logits) / num_true,
        array_ops.zeros_like(sampled_logits) ], 1)
	return out_logits, out_labels

sigmoid_cross_entropy_with_logits関数
は言うまでもありません。クロスエントロピー損失は、複数のラベルがあるため、マルチラベル分類であるため、sigmoidを使用します。注意すべき点の1つは、関数パラメーターlogitsで渡される値が元のwx +であるということです。 b値。シグモイド計算は関数で操作されます。

_sum_rows関数

def _sum_rows(x):
  #该函数的类似tf.reduce_sum(x,1)操作
  #官方给出用这样计算的理由是,计算梯度效率更高
  cols = array_ops.shape(x)[1]
  ones_shape = array_ops.stack([cols, 1])
  ones = array_ops.ones(ones_shape, x.dtype)
  # x:[batch_size, num_true+num_sampled]
  # ones: [num_true+num_sampled, 1]
  #x和ones两个矩阵相乘,得到[batch_size,1],再reshape [batch_size]
  return array_ops.reshape(math_ops.matmul(x, ones), [-1])

2.3デメリット

tensorflowソースコードから、各入力バッチの正のラベルの数は同じである必要があり、その数はnum_trueであるため、モデルを通常どおりトレーニングする場合、各バッチの正のラベルは同じである必要がありますが、実際にはアプリケーション、特にマルチラベル分類の場合、各サンプルのポジティブラベルの数に一貫性がありません。マルチタスクタスクでは、バッチが複数のタスクのラベルラベルで一貫しているとは限りません。

3一貫性のない数のポジティブラベルの解決策

上記の欠点を考慮して、テスト済みで実行可能な次のソリューションを試してください。

3.1パッドラベルをネガラベルとして追加する

コアアイデアサンプルを生成するとき、各サンプルのラベルの長さは一様にnum_trueです。不十分な場合は、インデックス0(パッドを表す)のラベルで埋めます。損失を計算するときは、パッドカテゴリを負に対応させます。ラベル
変更されたコードは主に次のとおりです。

#修改前函数
def _compute_sampled_logits(...):
	...
	out_logits = array_ops.concat([true_logits, sampled_logits], 1)
	# 对应的源代码生成label过程
	out_labels = array_ops.concat([
    	array_ops.ones_like(true_logits) / num_true,
    	array_ops.zeros_like(sampled_logits)], 1)
    return out_logits, out_labels

#修改后函数
def _compute_sampled_logits(...):
	...
	out_logits = array_ops.concat([true_logits, sampled_logits], 1)
	# 生成mask矩阵,其中真实的正label元素为1, 填充pad label为0
    mask = tf.cast(tf.not_equal(labels, 0), tf.float32)
    # 将pad的label都为负label 0
    true_y = array_ops.ones_like(true_logits) * mask 
    # 然后用div_flag控制是否需要对每个样本的label除以每个样本的个数
    # 这里动态的计算每个样本的真实label数量,因为每个样本pad的个数不一致
    if div_flag:
        dynamic_num_true = tf.reduce_sum(tf.sign(labels), reduction_indices=1)
        dynamic_num_true = tf.cast(tf.expand_dims(dynamic_num_true, -1), tf.float32)
        true_y = tf.div(true_y, dynamic_num_true)
    # 将正label和负label组合,得到out_labels返回
    out_labels = array_ops.concat([
        true_y,
        array_ops.zeros_like(sampled_logits)], 1)
 	return out_logits, out_labels  

4リファレンス

ノイズコントラスト推定:
正規化されていない統計モデルの新しい推定原理

おすすめ

転載: blog.csdn.net/BGoodHabit/article/details/110420952