Gaussian Prototype Network に関する原論文の高品質な翻訳

論文アドレス: Omniglot での少数ショット学習のためのガウス プロトタイプ ネットワーク

まとめ

オムニグロット データセットの K ショット分類のための新しいアーキテクチャを提案します。プロトタイプ ネットワークに基づいて、ガウス プロトタイプ ネットワークと呼ばれるそのアーキテクチャを拡張します。プロトタイプ ネットワークは、画像と埋め込みベクトルの間のマッピングを学習し、それらのクラスタリングを分類に使用します。私たちのモデルでは、エンコーダー出力の一部は、埋め込み点に関する信頼領域推定値として解釈され、ガウス共分散行列として表現されます。次に、私たちのネットワークは、個々のデータ点の不確実性を重みとして使用して、埋め込み空間上で方向およびカテゴリに依存する距離メトリックを構築します。ガウス プロトタイプ ネットワークは、同じ数のパラメーターを持つバニラ プロトタイプ ネットワークよりも優れたアーキテクチャであることを示します。5 ウェイおよび 20 ウェイ体制における 1 ショットおよび 5 ショット分類の Omniglot データセットに関する最先端のパフォーマンスを報告します (5 ショット 5 ウェイについては、以前の最先端のパフォーマンスと同等です)。アート)。パフォーマンスをさらに向上させるために、トレーニング セット内の画像のサブセットを人工的にダウンサンプリングすることを検討します。したがって、ガウス プロトタイプ ネットワークは、実世界のアプリケーションで一般的な、均一性が低くノイズの多いデータセットでより優れたパフォーマンスを発揮する可能性があると仮説を立てています。

1 はじめに

1.1 少数ショット学習

人間は、単一または少数の例で新しいオブジェクト カテゴリを認識することを学習できます。これは、手書き文字認識 [1] からモーション コントロール [2]、高レベルの概念の取得 [3] まで、幅広い活動で実証されています。この動作を機械で再現することが、少数学習を研究する動機になります。

パラメトリック深層学習は、大量のデータがある環境で良好にパフォーマンスを発揮します。一般に、深層学習モデルは非常に高い関数的表現力と機能を備えており、教師付きレジーム下でのゆっくりとした反復トレーニングに依存しています。したがって、トレーニングの目的はデータ セットの一般的な構造を把握することであるため、トレーニング セット内の特定の例の影響は最小限になります。これにより、トレーニング後に新しいカテゴリを迅速に導入することができなくなります。[4]

対照的に、少数ショット学習では、新しいデータに非常に迅速に適応する必要があります。特に、k ショット分類とは、トレーニング中に表示されなかったクラスを k 個のラベル付きサンプルを使用して学習する必要がある方式を指します。k 近傍 (kNN) などのノンパラメトリック モデルはオーバーフィットしませんが、そのパフォーマンスは距離メトリックの選択に大きく依存します。[5] は、パラメトリック モデルとノンパラメトリック モデルのアーキテクチャ、およびトレーニング条件とテスト条件のマッチングを組み合わせており、最近では k ショット分類で良好なパフォーマンスを示しています。

1.2 ガウス プロトタイプ ネットワーク

この論文では、[6] で使用されたプロトタイプ ネットワークに基づいて新しいアーキテクチャを開発し、それを Omniglot データセット [3] でトレーニングおよびテストします。Vanilla プロトタイプ ネットワークは、画像を埋め込みベクトルにマッピングし、分類にそのクラスタリングを使用します。彼らは、イメージのバッチをサポート イメージとクエリ イメージに分割し、サポート セットの埋め込みベクトルを使用して、クラス プロトタイプ (特定のクラスの典型的な埋め込みベクトル) を定義しました。これらへの近接性が分類に使用されます。

ガウス プロトタイプ ネットワークと呼ばれる私たちのモデルは、画像を埋め込みベクトルと画像品質の推定値にマッピングします。埋め込みベクトルとともに、その周囲の信頼領域が予測され、ガウス共分散行列によって特徴付けられます。ガウス プロトタイプ ネットワークは、埋め込み空間上で方向およびカテゴリに依存する距離メトリックを構築することを学習します。私たちのモデルは、バニラのプロトタイプ ネットワークと比較して、追加のトレーニング可能なパラメーターを使用する好ましい方法であることを示します。

私たちの目標は、モデルが単一のデータ ポイントに対する信頼性を表現できるようにすることで、より良い結果が得られることを示すことです。また、個々のデータ ポイントの重み付けがパフォーマンスにとって重要となる可能性がある、ノイズが多く不均一な現実世界のデータセットに対するアプローチのスケーラビリティを調査するために、データセットの一部を意図的に破損する実験も行いました。

私たちの知る限り、Omniglot データセットでは、5 ウェイおよび 20 ウェイ体制における 1 ショットおよび 5 ショット分類の最先端のパフォーマンスを報告しています (5 ショット 5 ウェイについては、以前の最先端の Advanced パフォーマンス同等品と比べても遜色ありません)。[3] ダウンサンプリングされたデータに対するモデルの応答を研究することにより、低品質で不均一なデータセットではその利点がより大きくなる可能性があると仮説を立てています。

この記事は以下のような構成になっています。関連する作業についてはセクション 2 で説明します。次に、セクション 3 で私たちの方法を紹介します。エピソード的なトレーニング計画もここで紹介されます。オムニグロット データセットについてはセクション 4 で説明し、実験についてはセクション 5 で説明します。最後に、セクション 6 で結論を示します。

2 関連作品

k 最近傍 (kNN) などのノンパラメトリック モデルは、これまでに見たことのないクラスを含めることができるため、少数の分類器の理想的な候補です。ただし、距離メトリックの選択には非常に敏感です。[5] 入力空間の距離 (生のピクセル値など) を直接使用すると、画像クラスとそのピクセル間の関係が非常に非線形であるため、高い精度が得られません。

[7]、[8]、[9]、および [10] で実証されているように、単純な修正、つまり kNN 分類に使用されるメトリックの埋め込みを学習すると、良い結果が得られます。[11] では、マッチング ネットワークを使用する方法が提案されており、基本的に画像のペア間の距離メトリックを学習します。このアプローチの注目すべき特徴は、各ミニバッチ (エピソードと呼ばれる) がクラスの数と各クラスの例の数をサブサンプリングすることによってデータプア テストを模倣しようとするトレーニング スキームです。このような方法により、少数の分類のパフォーマンスが向上することが示されています。[11] したがって、私たちもこのアプローチを採用します。

最近、データセットで直接学習する代わりに、エピソードを入力として与えられた確率的分類器の更新を予測するように LSTM [13] をトレーニングすることが提案されています [12]。このアプローチはメタ学習と呼ばれます。[14] および [15] に示されているように、メタ学習は Omniglot [3] で高い精度を達成しています。時間畳み込みに基づくタスク診断メタ学習器は、[16] で提案されています。パラメトリック手法とノンパラメトリック手法の組み合わせは、最近、少数の学習において最も成功しています。[6][17][18]

私たちの方法は特に画像分類を対象としており、メタ学習を通じてこの問題を解決しようとするものではありません。私たちは、[6] で提案されたモデルに基づいて構築します。このモデルは、画像を埋め込みベクトルにマッピングし、そのクラスタリングを分類に使用します。私たちのモデルの新しい特徴は、学習された画像依存の共分散行列を介して個々のデータ ポイントの信頼度を予測することです。これにより、画像を投影できる、より豊かな埋め込み空間を構築できます。次に、方向およびクラスに関連した距離メトリックに基づくクラスタリングが分類に使用されます。

3つの方法

この論文では、まず [6] で説明されているプロトタイプ ネットワークを検討します。私たちはこのアーキテクチャをガウス プロトタイプ ネットワークと呼ぶものに拡張し、ガウス共分散行列によって特徴付けられる埋め込みベクトルとその周囲の信頼領域を予測することで、モデルが個々のデータ ポイント (画像) の品質を反映できるようにします。

架空のプロトタイプ ネットワークは、画像を埋め込みベクトルにマッピングするエンコーダーで構成されます。バッチには、利用可能なトレーニング クラスのサブセットが含まれています。各反復では、各カテゴリの画像がサポート画像とクエリ画像にランダムに分割されます。イメージベースの埋め込みは、クラスのプロトタイプ (そのクラスの典型的な埋め込みベクトル) を定義するために使用されます。クラス プロトタイプへのクエリ画像の埋め込みの近さが分類に使用されます。

バニラ プロトタイプ ネットワークとガウス プロトタイプ ネットワークのエンコーダー アーキテクチャに違いはありません。主な違いは、エンコーダ出力がどのように解釈され使用されるか、および埋め込み空間のメトリックがどのように構築されるかです。ガウス ネットワークでは、エンコーダー出力の一部を使用して、埋め込みベクトルに関する共分散行列を構築します。これにより、モデルが予測力と個々のデータ ポイントの品質を反映できるようになります。

3.1 エンコーダ

明示的な最終完全接続層を使用せずに、多層畳み込みニューラル ネットワークを使用して、画像を高次元ユークリッド ベクトルにエンコードします。[6] で説明されている Nothingness のプロトタイプ ネットワークの場合、エンコーダーは画像 I を取得し、それをベクトル ~x に変換する関数です。つまり、H と W は入力画像の高さと幅、C はそのチャネルです
ここに画像の説明を挿入します
。量。D はベクトル空間の埋め込み次元であり、モデルのハイパーパラメータです。W はエンコーダのトレーニング可能な重みです。

ガウス プロトタイプ ネットワークの場合、エンコーダーの出力は、埋め込みベクトル~ x∈R Dと共分散行列 Σ∈R D×Dの関連成分の連結です。ここ
ここに画像の説明を挿入します
で、DS共分散行列の予測成分の次元です。

ガウス プロトタイプ ネットワークの 3 つのバリエーションを調査します。

  1. 半径共分散推定値D S =1 の場合、各画像はその埋め込みベクトルの信頼区間のサイズを記述するために1 つの実数 s raw ∈ R 1のみを生成します。したがって、共分散行列は Σ=diag(σ,σ,…,σ) の形式になります。ここで、σ は元のエンコーダー出力 s rawから計算されます。したがって、信頼度推定値には方向性がありません。Omniglot データセットでは、このアプローチが追加パラメーターの最も効率的な使用であることが証明されました [3]。この優先順位はデータセット固有のものである可能性があり、均質性の低いデータセットではより複雑な共分散推定が優先される可能性が高いと考えられます。
  2. 対角共分散推定D S =D、共分散推定の次元は、埋め込み空間の次元と同じです。 sraw∈R D は、埋め込みベクトルの周りの信頼区間のサイズを記述するために各画像上で生成されます。したがって、共分散行列は Σ = diag ( → σ)の形式になります。ここで、 σ は元のエンコーダー出力 sraw から計算されます。これにより、信頼楕円体は常に埋め込み空間軸と軸が揃ったままになりますが、ネットワークはデータ ポイントの方向信頼度を表現できるようになります。
  3. 完全な共分散推定完全な共分散行列がデータ ポイントごとに出力されます。このアプローチは、特定のタスクに対して不必要に複雑であることが判明したため、それ以上の検討は行われませんでした。

入力として、次元 28×28×1 のダウンサンプリングされたグレースケール オムニグロット画像を使用します。2 × 2 の最大アンサンブルを備えた 4 層 CNN アーキテクチャでは、形状 1 × 1 × (D + D S ) のボリュームが得られます。ここで、埋め込み次元 D に共分散行列 D の関連部分を加えたものが最終的なフィルター量と等しくなります。TensorFlow の同じパディングとストライド 1 を使用しています。私たちのフィルターの空間範囲は 3×3 です。最後の層は完全に接続された層と同等です。

2 つのエンコーダ アーキテクチャを使用しています。1) 小規模なアーキテクチャ、2) 大きなアーキテクチャ。この小さなアーキテクチャは、[6] で使用されているものに対応しており、これを使用して、以前の最先端の結果に対して独自の実験を検証します。大規模なアーキテクチャを使用して、モデル容量の増加が精度に及ぼす影響を観察しました。基本的な構成要素として、式 3 の層シーケンスを使用します。
ここに画像の説明を挿入します
どちらのアーキテクチャも、積み重ねられた 4 つのブロックで構成されています。アーキテクチャの詳細は以下の通りです。

  1. 小さな構造3×3 フィルター、フィルターの数は [64, 64, 64, D] です ([64, 64, 64, D+1] は半径ガウス モデル、[64, 64, 64, 2D] は対角ガウス モデルです)モデル)。調査された埋め込み空間の次元は、D=32、64、および 128 です。
  2. 大きな構造物3×3 フィルター、フィルター数は [128, 256, 512, D] (半径ガウス モデルは [128, 256, 512, D+1]、対角ガウス モデルは [128, 256, 512, 2D ]) 。調査された埋め込み空間の寸法は D=128、256、512 です。

エンコーダーの生の共分散行列出力を実際の共分散行列に変換する 4 つの異なる方法を検討しました。主に共分散行列 S=Σ-1 の逆値を扱っているため、それを直接予測します。raw エンコーダー出力の関連部分を S rawとします。以下のような方法。

  • S = 1 + Softplus(S raw )。ここで、softplus(x) = log (1 + e x )。これはコンポーネントごとに適用されます。Softplus(x)>0 であるため、S>1 が保証され、エンコーダはデータ ポイントの重要性を下げることしかできません。S の値も上記の制限を受けません。どちらの方法もトレーニングに有益であることが証明されています。私たちの最良のモデルは、初期トレーニングにこの体制を使用します。
  • S = 1 + シグモイド (S raw )。 sigmoid(x) = 1/ (1 + e -x ) であり、コンポーネントごとに適用されます。sigmoid(x)>0 (S>1 が保証される) であるため、エンコーダはデータ ポイントの重要性を低くすることしかできません。S < 2 であるため、S の値は上から制限され、エンコーダはさらに制約されます。
  • S = 1 + 4 シグモイド (S raw )、したがって、1 < S < 5 となります。これを使用して、共分散推定ドメインのサイズがパフォーマンスに及ぼす影響を調査します。
  • S = オフセット + スケール × ソフトプラス (S raw /div)。ここで、オフセット、スケール、および div は 1.0 に初期化されており、トレーニング可能です。私たちの最良のモデルは、最初のアプローチよりも柔軟でデータ駆動型であるため、トレーニング後にこの体制を使用します。

3.2 時折のトレーニング

プロトタイプ モデルの重要なコンポーネントは、[6] で説明されているエピソードトレーニング体制です。トレーニング プロセス中に、Nc クラスのサブセットがトレーニング セット内のクラスの総数から (置換なしで) 選択されます。これらのクラスごとに、N 個のサポート インスタンスと Nq 個のクエリ インスタンスがランダムに選択されます。サポートされているエンコード埋め込みの例は、埋め込み空間内の特定のクラス プロトタイプの位置を定義するために使用されます。クエリ インスタンスとクラス プロトタイプの位置の間の距離は、クエリ インスタンスを分類し、損失を計算するために使用されます。ガウス プロトタイプ ネットワークの場合、各埋め込み点の共分散も推定されます。このプロセスの概略図を図 1 に示します。
ここに画像の説明を挿入します

図 1: ガウス プロトタイプ ネットワークの機能図。エンコーダは、画像を埋め込み空間内のベクトル (黒丸) にマッピングします。各画像は共分散行列 (暗い楕円) も出力します。サポート イメージは、クラス固有のプロトタイプ (星形) と共分散行列 (明るい色の楕円形) を定義するために使用されます。中心点とエンコードされたクエリ画像の間の距離は、クラスの合計共分散によって補正され、クエリ画像の分類に使用されます。特定のクエリ ポイントまでの距離が灰色の破線で表示されます。

ガウス プロトタイプ ネットワークの場合、共分散行列の半径または対角が埋め込みベクトルとともに出力されます (より正確には、元の形式で、詳細についてはセクション 3.1 を参照してください)。これらは、特定のクラスのサポート ポイントに対応する埋め込みベクトルに重みを付け、そのクラスの全体的な共分散行列を計算するために使用されます。次に、クラスのプロトタイプ c からクエリ点 i までの距離 d c (i)は次のように計算されます。
ここに画像の説明を挿入します
ここで p cはクラス c の中心点またはプロトタイプ、S c = Σ -1 cはクラス c の中心点またはプロトタイプです。共分散行列の逆行列。したがって、ガウス プロトタイプ ネットワークは、埋め込み空間でクラスおよび方向に依存する距離測定を学習できます。トレーニングの速度と精度は、損失を構築するために距離がどのように使用されるかに大きく依存することがわかりました。最良の選択肢は線形ユークリッド距離、つまり d c (i) を使用することであると結論付けます。使用される損失関数の具体的な形式は、アルゴリズム 1 に示されています。図 2 は、ガウス プロトタイプ ネットワークの埋め込み空間図を示しています。付録の図 10 と図 11 は、トレーニング中の埋め込み空間のサンプルを示しています。これは、分類のための類似した文字のクラスタリングを示しています。
ここに画像の説明を挿入します

図 2: ガウス プロトタイプ ネットワークの埋め込み空間を示す図。画像はエンコーダによってその埋め込みベクトル (ダークスポット) にマッピングされます。その共分散行列 (暗い楕円) もエンコーダーによって出力されます。次に、各クラスの全体的な共分散行列 (大きな明るい色の楕円) と、クラスのプロトタイプ (星印) が計算されます。クラス共分散行列は、クエリ ポイント (灰色で表示) の距離メトリックをローカルに変更するために使用されます。

セクション 3.1 で要約したように、共分散行列が対角になる設定を検討します。半径の場合、S = sI、ここで I は単位行列、s∈R1 は各画像の生のエンコーダー出力から計算されます。対角線の場合、S = diag ( s)、ここで s も各画像の生のエンコーダー出力から計算されます。

3.3 クラスを定義する

プロトタイプ ネットワークの重要な部分は、特定のカテゴリに利用可能なサポート ポイントからカテゴリ プロトタイプを作成することです。私たちは、個々のサポート インスタンスの埋め込みベクトルの分散加重線形結合を解決策として提案します。クラス c に、埋め込みベクトル x c iとしてエンコードされるサポート イメージ I iと、対角が s c iである共分散行列 S c iの逆行列があるとします。プロトタイプ、つまりクラスの中心点は次のように定義されます。ここで、 はコンポーネントごとの乗算を表し、除算もコンポーネントごとに行われます。次に、準共分散行列の対角が次のように計算されます。これは、各点を中心とするガウスを全体的な準ガウスに最適化することに相当するため、ネットワークの名前は「ガウス」です。s の要素は実際には 1/σ 2です。したがって、式 5 および 6 は、例を1/ σ2で重み付けすることに対応します。これにより、ネットワークはクラスを定義する上でそれほど重要ではないサンプルの重みを下げることができるため、アーキテクチャがノイズの多い、不均等な、またはその他の不完全なデータセットにより適したものになります。
ここに画像の説明を挿入します

ここに画像の説明を挿入します

私たちのネットワークが訓練される方法であるワンショット体制の場合、各カテゴリを定義する単一のラベル ベクトル x cがあります。これは、ベクトル自体がクラスのプロトタイプとなり、その共分散行列がクラスに継承されることを意味します。共分散は、クエリ点からの距離を変更する役割を果たします。完全なアルゴリズムはアルゴリズム 1 で説明されています。
ここに画像の説明を挿入します

3.4 評価モデル

テスト セット上のモデルの精度を推定するために、k∈[1, …19] の範囲内の各サポート ポイントの数 Ns = k を使用してテスト セット全体を分類します。したがって、Omniglot はカテゴリごとに 20 個の例を提供するため、特定の k のクエリ ポイントの数は Nq = 20 - Ns になります。次に、精度が集計され、モデル トレーニングの特定の段階での k ショット分類精度が k の関数として決定されました。指定された検証セットを使用しなかったため、トレーニング精度が最も高い 5 つのテスト結果を考慮し、その平均と標準偏差を計算することで公平性を確保しました。これを行うことで、テスト セットでの最適化結果が回避され、結果として得られる精度の誤差限界がさらに得られます。既存の文献と直接比較するために、5 ウェイおよび 20 ウェイのテスト分類でモデルを評価します。

4つのデータセット

Omniglot データセットを使用しました。[3] Omniglot には、50 のアルファベット (現実および架空) からの 1623 の文字クラスと、各文字クラスの 20 の手書き、グレースケール、105 × 105 ピクセルの例が含まれています。それらを 28×28×1 にダウンサンプリングし、平均を減算して反転します。私たちは、[3] で提案され、[6] で使用されている 30 個のトレーニング文字と 20 個のテスト文字という推奨される分割方法を使用しました。トレーニング セットには 964 の一意の文字クラスが含まれ、テスト セットには 659 が含まれます。トレーニング データセットとテスト データセットの間にカテゴリの重複はありません。ハイパーパラメーターを微調整せず、トレーニング精度に基づいて最高のパフォーマンスのモデルのみを選択したため、別の検証セットは使用しませんでした (セクション 3.4 を参照)。

クラスの数を拡張するには、各文字を 90 °、180 °、および 270 °回転させてデータセットを増やし、各回転を新しい文字クラス自体として定義します。同じアプローチは [11] と [6] でも使用されています。強化された文字の例を図 3 に示します。これにより授業数は4倍に増加します。したがって、トレーニング セットには合計 77,120 個の画像が含まれ、テスト セットには合計 52,720 個の画像が含まれます。回転強化により、回転対称のキャラクターは引き続き複数のカテゴリに定義されます。仮定上の完全な分類器でも、「O」と回転した「O​​」などの文字を区別できないため、100% の精度を達成することはできません。
ここに画像の説明を挿入します

図 3: ローテーションによってクラスの数を増やす例。元のキャラクター (左側) を 90°、180°、270° 回転させます。各ローテーションは新しいクラスとして定義されます。これによりクラスの数が増加しますが、対称文字の縮退も生じます。

トレーニングを改善し、文字の共分散を予測するガウス ネットワークの機能を活用するために、一部の実験ではトレーニング セットの一部を意図的にダウンサンプリングしました。詳細についてはセクション 5 を参照してください。私たちの結果は、オムニグロット データセットが単純すぎるため、共分散行列を推定するガウス ネットワークの機能を十分に活用できないことを示しています。私たちは、実際のアプリケーションでは一般的な状況である、個々のデータ ポイントの品質が異なる異種データセットにおいて、この方法がより大きな利点を発揮すると仮説を立てています。

5つの実験

私たちは、Omniglot データセットに対して広範ないくつかの学習実験を実施しています。ガウス プロトタイプ ネットワークについては、さまざまな埋め込み空間次元、共分散行列の生成方法、およびエンコーダ機能を調査しました (詳細についてはセクション 3.1 を参照)。また、それらをバニラのプロトタイプ ネットワークと比較したところ、特に追加パラメーターを使用する最も効率的な方法は各埋め込み点 (セクション 3.1 の半径法) に対して単一の数値を予測することであるため、ガウス バリアントが有利であることが結果からわかりました。一般に、エンコーダのサイズ (セクション 3 で説明するように、小さいものと大きいもの)、ガウス/バニラ プロトタイプ ネットワークとの比較、距離メトリック (コサイン、√L 2 、L 2 、および L 2 2 ) 度数調査ます。ガウス ネットワークの共分散行列の自由度 (半径と対角推定、セクション 3.1 を参照)、および埋め込み空間の次元。また、入力データセットのサブセットをダウンサンプリングすることで、ネットワークに共分散推定の使用を促すことも検討し、これによりパフォーマンスが (k > 1) 倍向上することがわかりました。

初期学習率 2×10 -3の Adam オプティマイザーを使用します2000 イベント ≈ 30 エポックごとに学習率を半分にします。すべてのモデルは TensorFlow で実装され、Google Cloud の単一の NVidia K80 GPU で実行されます。各モデルのトレーニングには 1 日もかかりません。

トレーニング中に Nc=60 カテゴリ (60 方向分類) でモデルをトレーニングし、Nct=20 カテゴリ (20 方向分類) をテストしました。最高のパフォーマンスのモデルを得るために、最終的な Nct=5 (5 方向) 分類テストも実行して、結果を文献と比較しました。サポート ポイントの数を制限すると分類精度が向上することがわかったので、トレーニング中、ミニバッチに存在する各クラスは Ns = 1 サポート ポイントで構成されます。これは、トレーニング体制とテスト体制の一致として直感的に理解できます。カテゴリごとに残りの Nq = 20 - Ns = 19 個の画像がクエリ ポイントとして使用されます。

実験の詳細な結果を表 1 にまとめます。エンコーダの生の共分散出力から共分散行列を推定する 4 つの方法を検討します。詳細については、セクション 3.1 を参照してください。
ここに画像の説明を挿入します

表 1: 大規模なエンコーダ構造 (3×3 フィルター、4 層、フィルター数 = 128,256,512,-) のテスト結果。最終精度に対する共分散行列と埋め込み空間の次元の影響を比較しています。(a、b、c、d) には、生のエンコーダー出力を共分散行列に変換するさまざまな方法が含まれます。共分散の半径推定により、エンコーダの出力に次元が追加されます。対角推定により、エンコーダ出力の数が 2 倍になります。したがって、256 の埋め込み次元と対角共分散を持つ大規模なガウス ネットワークには、512 の仮想ネットワークと同じ数のパラメーターがあります。半径推定値には 1 つの次元が追加されているため、同じ埋め込み次元の架空のモデルに匹敵します。破損した列は、トレーニング セットがトレーニング中に意図的に部分的にダウンサンプリングされたことを示しています。

また、共分散推定が不必要に複雑でない限り、共分散推定としてエンコーダー出力を使用する方が、追加の埋め込み次元として同じ数のパラメーターを使用するよりも有利であることも検証しました。これは半径推定 (つまり、埋め込みベクトルごとに 1 つの実数) に当てはまりますが、対角推定はパフォーマンスに役立たないように見えます (パラメーターの数を等しく保つ)。この効果を図 4 と表 1 に示します。最もパフォーマンスの高いモデルは、最初に破損していないデータセットで 220 エポックにわたってトレーニングされました。次にトレーニングを継続し、100 エポックで画像の 1.5% を 24×24 にダウンサンプリングし、1.0% を 20×20 にダウンサンプリングし、0.5% を 16×16 にダウンサンプリングします。次に、20 エポックに対して 23×23 への 1.5% ダウンサンプリング、17×17 への 1.0% ダウンサンプリング、および 10 エポックに対して 23×23 への 1.0% ダウンサンプリングを使用します。これらの選択は任意であり、最適化されていません。データセットを意図的に破壊すると、共分散推定の使用が促進され、表 1 と図 5 に示すように (k > 1) の結果が増加しました。このセクションでは、Omniglot データセットが高品質すぎて、私たちのアプローチにとっては単純すぎるテストベッドであることを示します。トレーニング損失曲線を図 6 に示します。反復の関数としてのトレーニングとテストの精度も図 7 に示します。
ここに画像の説明を挿入します

図 4: 追加パラメータを割り当てる 2 つの方法の比較。追加のパラメーターを割り当てて、埋め込み空間 (半径) の次元を増やすか、より正確な共分散推定 (対角) を行います。半径推定 (埋め込みベクトルごとに 1 つの追加の実数を含む) は、対角推定よりも優れており、同じ数のパラメーターを持つ架空のプロトタイプ ネットワークよりも優れています。


ここに画像の説明を挿入します

図 5: K ショット テストの精度に対するトレーニング セットの一部のダウンサンプリングの影響。意図的に破損したデータでトレーニングされたバージョンは、共分散推定をより適切に活用することを学習するため、未変更のデータでトレーニングされたバージョンよりも優れています。


ここに画像の説明を挿入します

図 6: 反復の関数としての損失。黄色の縦線は、学習率が半分になる場所を示します。学習率を半分にすることによる有益な効果は、最初に確認できます。赤い部分は、部分的にダウンサンプリングされたトレーニング セットでのトレーニングに対応するため、損失が高くなります。


ここに画像の説明を挿入します

図 7: トレーニング精度とテスト精度の比較。この図は、大規模なガウス プロトタイプ ネットワーク (半径共分散推定) のトレーニング精度 (60 方向分類) を示し、それを 1 ショットおよび 5 ショット テストのパフォーマンス (20 方向分類) と比較しています。また、その結果を現在の最先端テクノロジーと比較します。[6]

表 2 にまとめたように、小規模なアーキテクチャで検証実験を実施し、[6] と同等の結果を得ました。この表は、Ns > 1 の条件下でトレーニングすると、つまりクラスを定義するためにより多くのデータ ポイントを使用すると、パフォーマンスが低下することも示しています。図 8 は、より大きなモデルにおけるより高い容量の効果を示しています。私たちのモデルと文献の結果との比較を表 3 に示します。私たちの知る限り、私たちのモデルは、Omniglot の 5 ウェイおよび 20 ウェイのテスト時間分類の両方で、最先端の 1 ショットおよび 5 ショットの結果を上回っています。特に 5 ウェイ 5 ショット分類では、完璧に近いパフォーマンス (99.73 ± 0.02 %) を達成したため、少数の学習アルゴリズムをさらに開発するには、より複雑なデータセットが必要であると結論付けられました。
ここに画像の説明を挿入します

表 2: 小規模アーキテクチャを使用した検証実験の結果。20通りの技術レベルは1発96.0%、5発98.9%。Ns はトレーニング中の各カテゴリのサポート ポイントの数です。すべてのトレーニングは、Nc=60 (60 方向分類) 体制の下で実行されます。ガウス プロトタイプ モデルの場合、σ∈S は推定された共分散行列の次元を表します。


ここに画像の説明を挿入します

図 8: 損失に対するモデル容量の影響。モデルが大きいほど、トレーニングが速くなり、全体的な損失が小さくなります。黄色の縦線は、学習率が半分になる場所を示します。


ここに画像の説明を挿入します

表 3: 他の論文と比較した、私たちの実験の最良の結果。すべてのトレーニングは、Nc=60 (60 方向分類) 体制の下で実行されます。私たちの知る限り、私たちのモデルは、1 ショットと 5 ショットの 20 方向分類、および 1 ショット 5 方向分類の両方で統計的に最先端のパフォーマンスを達成しています。5 ウェイの 5 つのケースで、当社のパフォーマンスは現在の最先端のものと同等です。

5.1 共分散推定の使用法

ガウス プロトタイプ ネットワークは、個々の埋め込まれた画像の共分散を予測する能力と、それによってそれらの重み付けを下げる可能性があるため、バニラ バージョンよりも優れているという仮説を検証するために、部分的にダウンサンプリングされた画像でトレーニングされた最高のパフォーマンスのネットワークのパフォーマンスを研究しました。トレーニングセット 予測値の分布。データの一部を意図的にダウンサンプリングし、その結果得られる共分散分布を調査しました。

ダウンサンプリングされた画像では、平均と分散が変化します。私たちのエンコーダーには各ブロックにバッチ正規化層があるため (詳細は式 3 を参照)、元の出力の特定の値の意味は現在のバッチに応じて変化します。私たちのモデルはバッチ正規化を使用してトレーニングされているため、共分散を調べるためにバッチ正規化をオフにすると、無関係な結果が得られます。

妥協のないデータセットの場合、共分散推定値の大部分は同じです。これは、ダウンサンプリングを通じて減損が人為的に導入された場合でも当てはまります。ただし、最後の層のバッチ正規化の影響により、分布はシフトされます。個々の逆共分散の意味をよりよく表現するために、最も頻度の高い値が互いに一致するようにヒストグラムを調整します。このアプローチは、最も支配的な値が元の出力 0 に対応し、それと異なる値のみが分類に影響を与えるため便利です。結果を図 9 に示します。
ここに画像の説明を挿入します

図 9: 元のデータセットと部分的にダウンサンプリングされたバージョンの予測共分散。ガウス ネットワークは、黄色の分布のより重い裾からわかるように、より高い S を予測することによって、侵害された例を軽減することを学習します。前縁と値の差のみが分類に影響するため、分布はまとめて配置されます。

6 結論

この論文では、少数派の写真分類のためのガウス プロトタイプ ネットワーク (プロトタイプ ネットワーク [6] に基づいた改良されたアーキテクチャ) を提案します。私たちは Omniglot データセットでモデルをテストし、共分散行列推定値を生成しベクトルを埋め込むさまざまな方法を調査しました。同等のパラメータ数では、ガウス プロトタイプ ネットワークがバニラ プロトタイプ ネットワークよりも優れているため、アーキテクチャの選択が有益であることを示します。埋め込みベクトルに対して単一の実数を推定する方が、対角行列または完全な共分散行列を推定するよりも効果的であることがわかりました。私たちは、品質が低く、均質性が低いデータセットは、より複雑な共分散行列推定を好む可能性があると仮説を立てています。[6] とは反対に、ネットワークが 1 ショット方式でトレーニングされた場合に最良の結果が得られることがわかりました。次に、モデルをスケールアップし、5 ウェイと 20 ウェイの両方のテスト体制で 1 ショットおよび 5 ショットの分類で最先端のパフォーマンスを達成することができました (5 ショット 5 ウェイについては、同等のものを比較しました)最先端のものへ)。トレーニング データセットのサンプリング レートを人為的に下げて、ネットワークが共分散推定を最大限に活用できるようにすることで、(特に (k>1) ショット分類の) 精度を向上させることができました。特に 5 方向分類の場合、私たちの結果は完璧なパフォーマンスに非常に近いため、少数派の写真分類のさらなる開発は、オムニグロットよりも複雑なデータセットに焦点を当てる必要があると結論付けています。私たちは、エンベディングとその不確実性を学習する機能は、実際のアプリケーションではよくある低品質のデータセットでより有益であると仮説を立てています。そこでは、一部のデータ ポイントの重み付けを下げることが、忠実な分類の鍵となる可能性があります。Omniglot でのダウンサンプリング実験はこれを実証しています。

おすすめ

転載: blog.csdn.net/qq_56039091/article/details/127794318