ST-MGAT: 交通予測のための時空間マルチヘッド グラフ アテンション ネットワーク

ST-MGAT: 交通予測のための時空間マルチヘッド グラフ アテンション ネットワーク

まとめ:

        グラフ ニューラル ネットワーク (gnn) は、グラフの表現を学習できるため、ますます注目を集めています。交通予測は典型的なグラフ表現学習タスクですが、交通における複雑な時空間関係をモデル化することは困難です。従来のスペクトル手法は、グラフのラプラシアン行列に依存する固有分解に基づいてフィルターを取得します。ただし、これらの方法はグラフ畳み込みニューラル ネットワーク上で高価な行列演算を必要とし、空間依存性の問題を解決するには十分ではありません。本論文は,交通予測問題に対処するための新しいタイプのグラフニューラルネットワークである時空間マルチヘッドグラフアテンションネットワーク(ST-MGAT)を提案する。グラフ上で直接畳み込みを構築します。近傍ノードの特性とエッジの重みを考慮して、新しいノード表現を生成します。より具体的には、2 つの主要なモジュールがあります: i) 動的な時間相関を捕捉する時間畳み込みブロック、ii) ノード間の動的な空間関係を捕捉する注目を集めるネットワーク グラフ。実験結果は、私たちのモデルが短期、中期、長期の高速道路交通予測において現在の最先端の方法を 13% 改善することを示しています。

グラフ畳み込みネットワーク:       

        空間ドメイン法: 空間法では、グラフの構造を空間的に考慮します。つまり、ターゲット ノードと他のノードの間の幾何学的関係が考慮されます。課題は、ノードの新しい特徴を生成することです。これは、隣接するノードの特徴を収集して集約することによって実現されます。本論文ではグラフ畳み込みネットワークの空間領域法であるマルチヘッドグラフアテンションネットワークを採用する。エンコーダ/デコーダまたはスペクトル領域メソッドを使用する代わりに、アテンション メカニズムを使用してグラフ上に直接畳み込みを構築します。

時空モデル:

        一般に、時空間モデルは、畳み込みニューラル ネットワーク (CNN) とリカレント ニューラル ネットワーク (RNN) の 2 つの方法に分類できます。ただし、RNN ベースの手法は、長いシーケンスに対しては効果がなく、勾配が爆発する可能性があるという課題に直面しています。トラフィックデータの時間的関係をシミュレートするために、ゲート機構を備えた拡張畳み込み構造を採用しています。

グラフ畳み込みにおける注意メカニズム:

        アテンションベースのグラフ畳み込みの主なアイデアは、エッジ情報を持つノードを集約することによって新しいノード表現を生成することです。ここで、注目係数は、グラフ内の各ノードの隣接ノードに対する相互重要度です。この論文では、グラフを直接畳み込み、アテンション メカニズムを通じてネットワークを貴重な情報に集中させます。

問題の定義:

交通ネットワークをグラフとして扱います。そのタスクは、次のいくつかのタイム ステップでノードの特性を予測することです。双方向車線を 2 つの車線に抽象化します。車線をグラフのエッジとして扱い、車線上に道路感知器をグラフ上の点として配置し、速度、流れ、占有率などの交通状況の尺度もグラフのノードの特徴として選択します。ネットワークは、無向グラフ G = (V, E, A) として合理的に定義できます。ここで、V はノードの有限セット、E はエッジのセットです。隣接行列は と表されます。ここで言及したグラフは無向グラフであり、この記事の目的は、将来のある時点でのグラフ上のすべてのノードの特性を予測することであることを強調しておく必要があります。グラフ上のノードは道路上の検出器から選択されます。さらに、検出器によって生成されたデータはノードの特徴です。 X はノードの単一の信号ではなく、ノード全体を含むグラフ信号であることを強調しておく必要があります。交通流予測は、過去の測定値 (速度、量、占有率など) を使用して、次の S タイム スライスの交通流を予測することです。入力、 出力、N は観測点データ、F は各ノードの特性、T は入力タイム ステップ、P は出力タイムステップが長い。 は、時刻 t におけるノードのベクトルを表します。グラフ上のすべてのノードと履歴シーケンス X を考慮して、次の P タイム ステップのトラフィック速度を予測します。

 方法:

        このセクションでは、フレームワークの 2 つの主要な部分を紹介します。空間レイヤーは、アテンション係数を使用してノード フィーチャを集約することによって新しいノード表現を生成するグラフ アテンション ネットワーク (GAT) によって構築されます。時間層は、時間特性を捕捉して時間の浪費を防ぐゲート機構を備えた拡張畳み込み構造で構成されます。レイヤーはスタックされ、レイヤーに正規化を適用することで過剰適合を防止しながら、予測精度を向上させます。最後に、モデルは全結合層を追加することにより、次の t タイム ステップで n ノードの出力を生成します。次に、私たちのアプローチの枠組みを概説します。

フレーム:

        グラフネットワークの枠組みを図に示します。一般に、私たちのモデルは、ゲート機構と空間ベースのグラフ アテンション コンボリューション ブロックを備えた拡張コンボリューションと、それに続く出力用の全結合層で構成されます。

        入力データは X∈RN×T×F です。ここで、N はノードの数、T はタイム ステップ、F は各ノードの特性です。時間層の準備は、2 次元の畳み込みを通じて入力データの特徴拡張を実現することです。 2 つの同一のアンラップ畳み込み層は、それぞれアンラップ カーネル サイズが 1、2、および 4 の特徴拡張データを受け入れます。アダマール積は、2 つの並列畳み込み層 (ゲート ユニット) の要素に適用されます。続いて、グラフ アテンション コンボリューション (GAT) の 2 つのレイヤーが重ね合わされ、シーケンシャル レイヤーの結果が処理されます。同時に、未処理のデータとグラフ畳み込みによって処理されたデータを融合するために、残りのネットワーク層がセットアップされます。構造は前述したように一層であり、複数の層が積層されている。グラフ アテンション畳み込み層の入力は特徴 F∈RN×D であることを強調しておく必要があります。ここで、N はノードの数、D は入力特徴のサイズです。

        出力は です。ここで、Dout はフィーチャの新しいサイズです。ゲート層の出力は で、T はタイム ステップです。この記事では、3 次元データ Xgated を 2 次元データに変換します。ここでです。この手法は、トラフィックデータの時間情報をノードの特性に融合します。

        要するに、時間層は、時間特性を捕捉して時間の浪費を防ぐゲート付き時間畳み込みブロックで構成されます。空間レイヤーは、アテンション係数を使用してノード フィーチャを集約して新しいノード表現を生成するグラフ アテンション ネットワークによって構築されます。レイヤーはスタックされ、レイヤーに正規化を適用することで過剰適合を防止しながら、予測精度を向上させます。最後に、モデルは全結合層を追加することにより、次の t タイム ステップで n ノードの出力を生成します。

 

 ST-MGAT アーキテクチャ。入力は X∈RN×T×F です。ここで、N はノードの数、T は時間ステップ、F は各ノードの特徴です。出力は Yout∈RN×T で、t タイム ステップでの N ノードの予測速度を表します。入力データを処理するための線形アプローチは、元のデータ特徴の次元を増やすことです。各チャネルは、1 × 1 のコンボリューション カーネルを使用した複数の 2 次元コンボリューションを使用します。フィルタリングブロックは2次元コンボリューションを採用しています。ゲートブロック、残差ブロック、集計ブロックは一次元畳み込みを採用しています。私たちのモデルには b 層があり、それぞれに 2 つのグラフ畳み込みブロックが積み重ねられています。

グラフ アテンション レイヤー: はノード i の特性を表し、 はノード i の更新を表します。隠れた機能。 N はノードの数、F はフィーチャの数です。

 グラフの畳み込み:

        グラフ アテンション ネットワークは、アテンション メカニズムを使用して隣接ノードの特徴の重み付き合計を実行する方法を提案します。特徴ノードの重みはノードの特性に完全に依存し、グラフの構造とは何の関係もありません。この方法は、スペクトル グラフ畳み込みネットワークのボトルネックを克服し、異なる近傍に異なる学習重みを割り当てるのが簡単です。

        アテンションベースのグラフ畳み込みネットワーク (GAT) とスペクトルベースのグラフ畳み込みネットワーク (GCN) の主な違いは、1 ホップ離れた隣接ノードの特徴表現を収集して要約する方法です。頂点特徴間の相関関係がモデルにうまく統合されるため、GAT はある程度まで強力になります。基本的な利点は、計算がノードごとに行われることです。各操作では、グラフ上のすべての頂点をループしてノードの特徴を集約する必要があります。頂点ごとの操作は、ラプラシアン行列の制約が排除されることを意味します。他のアテンション メカニズムと同様に、GAT の計算は、アテンション係数の計算と重み付けされた特徴の集約の 2 つのステップに分かれています。

 

         式からわかるように、共有パラメータ W の線形マッピングは、特徴強調で一般的に使用される特徴である頂点特徴まで次元を拡張します。この方法では、頂点の変換特徴量を連結し、高次元特徴量を a(*) で実数 eij にマッピングします。この機能は、単層フィードフォワード ニューラル ネットワークを通じて実装されます。頂点 i と j の間の相関は、学習可能なパラメーター W とマッピング関数 a(*) を通じて学習されます。

        相関係数の正規化:

 

         ノードの受信端のアテンションスコアはソフトマックスで正規化し、上記の相関係数を使用します。特徴は、計算された注意係数に従って重み付けされ、集計されます。 は、GAT が各頂点 i を融合した後の新しい特徴出力であり、σ(*) は活性化関数です。

 

        畳み込みニューラル ネットワークのマルチコアと同様に、マルチヘッド アテンション メカニズムを使用してモデルの機能を強化し、トレーニング プロセスを安定させます。各アテンションヘッドには独自のパラメータがあります。 K は注目する頭の数です。 h は GAT の新機能であり、各頂点の近傍情報を融合する活性化関数です。中間層には連結を使用し、最後の層には平均化を使用することをお勧めします。

 

        GAT における重要な学習パラメータは W と a(*) です。上記の頂点ごとの計算方法により、これら 2 つのパラメーターは頂点の特性にのみ関連し、グラフの構造とは何の関係もありません。したがって、グラフの構造を変更しても、テスト タスクの GAT にはほとんど影響がありません。

グラフ構造の並列計算:

        新しいアプローチは、ループの各バッチ中にグラフの畳み込みを実行する代わりに、グラフを 1 つの大きなグラフに結合することです。たとえば、グラフ g には n 個のノードと e 個のエッジがあり、畳み込み, の入力となります。

 

        ここで、fin はノードの特性を表し、G() は畳み込み演算を表し、fout はノードの新しい特性を表します。上記の操作をバッチ サイズの時間ループで実行する必要はありません。ただし、グラフのバッチ処理は課題に直面しています。グラフはまばらな場合もあれば、大きい場合もあります。対照的に、グラフを大きなグラフにバッチ化し、畳み込みを並列処理します。図に示すように、バッチの出力は依然としてグラフです。これは、基本的なグラフに対する操作がバッチを返す際にも引き続き有効であることを意味します。

 

 時間畳み込み層:

        図に示すように、タイムライン上のゲートを備えた拡張畳み込み構造を使用して、時間相関を抽出します。時間畳み込み層は 1 次元の畳み込みを設定し、その後にゲート線形ユニットが続きます。

 拡張された畳み込み。カーネル サイズを 1 および 2 に設定    

        拡張畳み込みの利点は、プーリング情報を失わずに受容野を増加させるため、各畳み込み出力にはより広範囲の情報が含まれることです。アンロールド畳み込みは、画像セグメンテーション、音声合成、機械翻訳など、画像がグローバル情報または音声を必要とし、テキストが長いシーケンス情報を必要とする問題に適用できます。拡張畳み込みの操作により、特徴マップの相対的な空間位置が保存されます。これは、モデルが受容野を改善し、履歴情報を考慮に入れることを意味します。拡張因果畳み込み手法は次のようになります。

式中、F は 2 次元シーケンス (画像)、s はドメイン、K はカーネル関数、t はドメイン、L は拡張係数、P は拡張畳み込みのドメインです。上の式は 1 次元の場合と変わりません。​ 

        カーネル サイズが k で、拡張畳み込みのストライドが r の場合、受信フィールドは k * k から k + (r-1) * (k-1) に変化し、後半部分は挿入される数値はゼロです。畳み込みを展開した後、ゲート線形単位が適用されて、層を通過する情報が決定されます。図に示すように、入力を次のように設定します。コンボリューション カーネルは入力 X をにマッピングします。

 

 ここで、 および は要素のアダマール積を表します。 σ() はシグモイド ゲート関数であり、どの情報が次のステップに伝播されるかによって決まります。

 実験:

データセット:

METR-LA には、2012 年 3 月 1 日から 2012 年 6 月 30 日までの 4 か月間、ロサンゼルスの 207 台の検知器からの交通情報が含まれています。 PEMS-BAY には、2017 年 1 月 1 日から 2017 年 5 月 31 日までの 6 か月間、ベイエリアの 325 台の検知器の交通情報が含まれています。検出器によって記録されたデータは 5 分ごとにグループに分割されます。この実験では、1 日を 288 の時間枠に分割します。

 ベースライン:

        私たちのモデルを他のいくつかの古典的なモデルと比較します。結果を公平にするために、一部のモデル (STGCN など) が複製され、同じデータセットに適用されます。一部の結果は論文内のデータを直接参照しています (GaAN、GWaveNet など)。読者は公開コードでいくつかのベースライン モデルを見つけることができます。

ARIMA: 自己回帰統合移動平均

LSTM:長期短期記憶

GaAN: グラフ畳み込みブロックとリカレント ニューラル ネットワークを使用したエンコーダー/デコーダー ネットワークの構築

DCRNN: 拡散畳み込み再帰ニューラル ネットワーク

WaveNet: 時系列タスク用のネットワーク

Graph WaveNet: 時系列タスクのためのグラフ畳み込みネットワーク

STGCN: 時空間グラフ畳み込みネットワーク

結果と分析:

 ベースラインと比較します:

畳み込み比較あり/なし:

 

 話し合う:

        ​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​​彼女自身によってマルチコンポーネントフュージョン技術 [27] からインスピレーションを得たこの技術は、過去の交通データにおける最近、毎日、および毎週の周期的な時間相関を捕捉します。過去 1 日と 1 週間の同じノードの履歴を追加しましたが、役に立ちませんでした。タイム ステップの数が増加するにつれて、交通量の予測はわずかに改善される傾向があります (水平線は 12 ステップ)。最後に、現時点に基づく次の期間の交通流予測は、主に過去数時間の交通流によって制御されることが検証される。同様に、損失関数を MAE から RMSE に変更すると、RMSE は改善されますが、他の指標は改善されません。入力タイム ステップを変更し、グラフ畳み込み隠れチャネルを追加すると、予測精度がわずかに向上しますが、モデル パラメーターが増加し、消費時間が増加します。

結論と今後の課題:

        本稿では、アテンションベースのグラフ畳み込みネットワークを含む新しい交通流予測モデルST-MGATを提案する。私たちの知る限りでは、スペクトルベースの方法ではなく空間ベースの方法を交通流予測タスクに初めて適用し、モデルの汎化能力を強化しました。実験の結果、このモデルは畳み込みベースの方法を上回り、変化する道路状況への対応の適応性が向上することが示されています。今後の作業では、モデルを一般的なグラフに適用し、気象条件などの補足情報を追加して、モデルの精度をさらに向上させます。

おすすめ

転載: blog.csdn.net/weixin_53187018/article/details/130588292