GNNとグラフ信号処理
概要
GNNは, 信号処理の観点でローパスフィルタの効果を持っている. そして, GNNが通常のMLPに比べて高精度となるようなデータでは, 低周波成分に(タスクに関して)有用な情報が多く含まれており, 高周波成分には無駄な情報が多いようなケースになっているように見受けれられる. また, 実際にタスクの精度を様々なフィルタで比較すると, 信号処理の観点での性質が反映された結果を与えている. GNNの最も大きな役割(通常のMLPとの違い)は, ローパスフィルタとしての性質の影響なのではないかと考えられる. 逆に言うと, ローパスフィルタとしての性質が生かせるようなデータが, GNNの効果が最も期待出来るものなのではないかと考えられる.
内容
- はじめに
- GNNとはどのようなものだったか?
- 信号処理のいくつか
- GNNのローパスフィルタとしての役割
- まとめ
1. はじめに
Graph Neural Network(GNN)は, グラフ分析の領域で近年State of the Artの精度を記録しているが, その表現能力に関しては, いくつか疑問が持たれています. 例えば, GNNのさきがけの論文であるT.Kipf & M.Welling(ICLR2017)では, 層を増やしても精度が向上しないことが報告されています:
(図)T.Kipf & M.Welling(ICLR2017)における実験結果 (ノードのクラス分類タスク) Testの部分を見ると, 2, 3層での精度が最も良く, それ以降は減少していることが分かります.
これは, 通常のニューラルネットワークとは少し様相が異なるように思えます. これに対し, 最近2つの論文,
- [1] Simplifying Graph Convolutional Networks, F. Wu et al, ICML2019
- [2] Revisiting Graph Neural Networks: All We Have is Low-Pass Filters, Hoang NT and T.Maehara, arXiv:1902.07153
でグラフ信号処理の観点から, GNNの特性に関する議論がされています. 今回は, GNNが, グラフ信号処理の観点からどのように考えられるのかを簡単に見ていくことにします.
Note:
ここで記載している内容は, 厳密な議論ではないことに注意してください. 詳細な議論は論文の方の参照をお願いします. また, 記載に誤りがあるかもしれません. もし,間違いなどがあれば,ご指摘いただけると大変助かります.よろしくお願いします.
2. GNNとはどのようなものだったか?
前提
今回は「重みなし無向グラフ」を前提とします.
重みなし, つまり隣接行列は, エッジがある場合は , そうでない場合は とします. また, 有向グラフの場合, 隣接行列が対称行列とならないため, 理論的な考察が難しくなります.
「重みなし無向グラフ」以外のグラフは, 今回の記事で扱う範囲からは外れることに注意してください.
GNN概観
(参考) http://tkipf.github.io/misc/SlidesCambridge.pdf
基本的には, 自分自身も含めて, 隣接している特徴量を集約させるところが肝で, それ以降は, 通常のニューラルネットワークと同様に重みパラーメータ を掛けて, 活性化関数 を掛けるということにまとめられます. ここで, はノード の隣接ノードの集合を表しています.
一応, よく間違われる注意点をまとめておくと,
グラフの構造自体は, 更新では変わらない. 変わるのは, ノードの特徴量のみ.
隣接ノードの情報を集約するが, この集約部分は学習パラメータを持たない. (ある一定の規則で集約させるのが基本, エッジ重みやノードが持つエッジ数など, ただし, GAT [P.Veličković et al, ICLR2018] など発展的なものは学習パラメータを加えている)
集約した後は、単に重み行列を掛けて、活性化関数を掛けるだけで通常のニューラルネットワークと同じ
自分自身の寄与(1項目)と, 隣接ノードの寄与に対する重みを分けるやり方もある (GraphSAGE [W.L.Hamilton et al, NIPS2017] など)
隣接特徴量の足し上げの行列表現
上記で, 隣接ノードに対する特徴量を集約するといったが, これはどのように行列演算として表されるでしょうか?
各ノードの特徴量を とし, これを下記のようにノード数 分だけ並べて行列として表します:
$$ \begin{align} H = \begin{pmatrix} h_{1}^{T} \\ h_{2}^{T} \\ \vdots \\ h_{N}^{T} \end{pmatrix} \end{align} $$
そこに, 隣接行列 を掛けると,
$$ \begin{align} AH = \begin{pmatrix} a_{11}h_{1}^{T}+a_{12}h_{2}^{T}+\cdots+a_{1N}h_{N}^{T} \\ a_{21}h_{1}^{T}+a_{22}h_{2}^{T}+\cdots+a_{2N}h_{N}^{T} \\ \vdots \\ a_{N1}h_{1}^{T}+a_{N2}h_{2}^{T}+\cdots+a_{NN}h_{N}^{T} \end{pmatrix} \end{align} $$
隣接行列は, ノード と がエッジを持つ場合, 成分はノンゼロであるような行列なので,
$$ \begin{align} (AH)_{i,\,:} = \sum_{j \in \mathcal{N}(i)} a_{ij}\,h_{j}^{T}, \quad \text{where}~~\mathcal{N}(i)\,\text{: ノード $i$ の隣接ノードの集合} \end{align} $$
隣接しているノード に関する特徴量の和になります. つまり, 隣接行列 を掛けることにより, 隣接ノードの特徴量の集約が実現出来ることが分かります.
GNNの行列表現
さらに, GNNでは, (1) 次数正規化, (2) セルフループの追加を行います.
(1) 次数正規化:
隣接ノードの特徴量を集約するのですが, その際に, エッジ に関して, それぞれの端ノードの次数で正規化します:
$$ \begin{align} A \to D^{-1/2}A D^{-1/2}, \quad \text{where}~~ D=\text{diag}(\sum_{i=1}^{N}a_{1i}, \sum_{i=1}^{N}a_{2i}, \cdots, \sum_{i=1}^{N}a_{Ni}) \end{align} $$
これは, 成分表示で表すと,
$$ \begin{align} (D^{-1/2}A D^{-1/2})_{ij} = \frac{1}{\sqrt{\sum_{u=1}^{N}a_{iu}\sum_{v=1}^{N}a_{iv}}}\,\,a_{ij} \end{align} $$
となります.
(2) セルフループの追加:
隣接ノードの特徴量を集約しますが, それよりも自分自身の特徴量の寄与の方が重要であると考えられます. そのため,
$$ \begin{align} A \to I + A \end{align} $$
隣接行列に, 単位行列を加えることで, セルフループの寄与を加えます.
(1)と(2)を行うことで, 結局,
となります. ただし, 実はこれではうまくいかないことが分かっており, GNNでは, renormalization trick (T.Kipf & M.Welling(ICLR2017)) と呼ばれる正規化をした下記が使われます:
上記の2つを, 行成分だけ書いてみると,
これは, 隣接する部分は正規化して, 自分自身の寄与は最大値で集約させるような効果を示しています.
それに対し, renormalization trick の方は,
$$ \begin{align} \big{\{}\,(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2})\,H \, \big{\}}_{i,\, :} &= \big{\{} \, (\tilde{D}^{-1} + \tilde{D}^{-1/2}A\tilde{D}^{-1/2})\,H \, \big{\}}_{ij} \\ &= \frac{1}{(d_{i}+1)}\,h_{i}^{T} + \sum_{j \in \mathcal{N}(i)} \frac{1}{\sqrt{(d_{i} +1)(d_{j}+1)}}\,h_{j}^{T} \end{align} $$
これだけだと, この trick の効果は分かりづらいと思います. これをグラフ信号処理の観点で, どのような効果があるのかを見ていくのが, この記事の本題となります.
3. 信号処理のいくつか
GNNをグラフ信号処理の観点で考える前に, 信号処理の部分に関して, ここで必要となる知識だけ簡単にまとめておきます.
この辺りは, Qiitaの記事,
- Qiita グラフ畳み込み再考, @cotton-gluon, https://qiita.com/cotton-gluon/items/5c4e2f9c2c8a120863fa
がとても分かりやすく, 今回の記事も多く参考にしています.
フィルタ
グラフ上の信号処理に入る前に, 通常の信号処理の基本的なところをおさらいします. ただし, ここでは時間変化する信号は扱いません.
まず, 信号処理の重要なキーワードの一つとして, 信号の特徴を抽出する「フィルタ」があります. 数学的に書くと, 信号 に対して, フィルタを通した後の信号 は,
と表されます. ここで, は窓関数と呼ばれ, フィルタの性質を表します. 例えば, 矩形窓と呼ばれる次のような窓関数を考えます,
これを作用させると, $$ \begin{align} \tilde{f}(x) = \frac{1}{2W} \int_{-W}^{W} f(s)\, ds \end{align} $$
となり, ある一定幅における平均を抽出するフィルタを表していることが分かります.
もう一つ代表的な例として, ガウシアンフィルタがあります.
このフィルタの特徴としては, スケール に対して, 大きな幅で振動する緩やかな波の場合は, より強調された信号が抽出されます. 一方で, スケール に比べて, 短い幅で激しく振動する波の場合, 平滑化された波が抽出されます. つまり, 低周波成分は強く, 高周波成分を弱める, 「ローパスフィルタ」の役割をしています.
(図) Qiita グラフ畳み込み再考, https://qiita.com/cotton-gluon/items/5c4e2f9c2c8a120863fa
フーリエ変換とラプラシアン
一般に, 信号 は, と の様々な周波数の波に分解することが出来ます(フーリエ変換).
例えば, 一つの 波に, ガウシアンフィルタを作用させてみると,
$$ \begin{align} \tilde{f}(x) &= \int_{-\infty}^{\infty} w(s) f(x-s)\,ds \\ &= \frac{1}{\sqrt{2\pi \text{σ}^{2}}} \int_{-\infty}^{\infty} \,e^{-s^{2}/2 \text{σ}^{2}}\, \cos [\omega (x-s)]\,ds \\ &= \cdots \\ &= e^{-\frac{\text{σ}^{2}}{2}\omega^{2}} \,\cos(\omega x) \end{align} $$ (この計算では, オイラーの公式やガウス積分の複素版を使っています)
ここから, ガウシアンフィルタは, 波の形は変えずに, 振幅を減衰する作用を与えていることが分かります. 特に, 周波数 が大きいほど減衰の大きさも大きくなることが分かります. 波に関しても同様の結果が得られます.
また, このフィルタの作用は,
$$ \begin{align} \tilde{f}(x) = &= e^{-\frac{\text{σ}^{2}}{2}\omega^{2}} \,\cos(\omega x) \\ &=\exp\Big{[} -\frac{\text{σ}^{2}}{2} \underbrace{\Big{(} -\frac{d^{2}}{dx^{2}} \Big{)}}_{=\Delta} \Big{]}\, \cos(\omega x) \end{align} $$
として, 微分作用素, ラプラシアン の観点で書き直すことが出来ます. つまり, ラプラシアンと信号の周波数が結びついていることが分かります.
そして, ここから,
ということも分かります. つまり, フーリエ変換の一つ一つの波の成分は, ラプラシアンに対する固有関数となっており, また, その固有値は周波数の2乗になっていることが分かります. フーリエ変換は, ラプラシアンという微分作用素の固有関数による展開として捉えることが出来ます.
グラフ上での信号処理
では, グラフ上の信号処理に関して考えていきます. グラフ構造上で, 各ノードが何か信号値を持っているような状況を想定します.
(参考)グラフ信号処理のすゝめ, https://www.jstage.jst.go.jp/article/essfr/8/1/8_15/_pdf (赤い線がグラフ構造, 青い線が各ノード上の信号値を表す)
このようにグラフ上での信号を考える場合, フーリエ変換はどのように表すことが出来るでしょうか?
先ほどの通常の信号処理で見てきたように, ラプラシアンを導入することによって, フーリエ変換を定義出来ると考えられます.
では, グラフ上での信号の微分はどのように導入出来るでしょうか? ここでは, 直感的な説明となりますが,
に沿って説明します.
次のようなグラフ信号を例として考えてみます:
ここで, 次のような接続行列(Incidence Matrix) を考えます,
これは, ノード数 エッジ数の行列で, 各エッジがあるところに対応するノードがあるところに値を持ちます. この時, ノード番号が小さいものの値を , 大きいものの値を とします. このような行列を導入すると,
となり, 各エッジごとの信号値の差分が得られます. つまり, この は微分作用素として見なすことが出来そうです. すると, ラプラシアン は, 2階微分作用素(勾配の発散)なので,
$$ \begin{align} \Delta = K K^{T} &= \begin{pmatrix} 1 & 0 & 0 & 0 \\ -1 & 1 & 1 & 0 \\ 0 & -1 & 0 & 0 \\ 0 & 0 & -1 & 1 \\ 0 & 0 & 0 & -1 \end{pmatrix} \begin{pmatrix} 1 & -1 & 0 & 0 & 0 \\ 0 & 1 & -1 & 0 & 0 \\ 0 & 1 & 0 & -1 & 0 \\ 0 & 0 & 0 & 1 & -1 \\ \end{pmatrix} \\ &= \begin{pmatrix} 1 & -1 & 0 & 0 & 0 \\ -1 & 3 & -1 & -1 & 0 \\ 0 & -1 & 1 & 0 & 0 \\ 0 & -1 & 0 & 2 & -1 \\ 0 & 0 & 0 & -1 & 1 \end{pmatrix} \end{align} $$
というように表されます. これは, グラフラプラシアンと呼ばれ, 先ほどの信号処理の観点から, この行列の固有値と固有ベクトルを求めることで, グラフ上の各周波数成分の波が分かります.
このグラフラプラシアンは, 対角成分を見てみると, 各ノードのエッジ数の合計値が入っていることが分かります(]). また, 各行ごとに非対角成分を見てみると, エッジを持っている成分に が入っていることが分かります. つまり,
という関係になっていることが分かります. また, 色々面白い性質を持っています. 特に, 正規化されたグラフラプラシアン,
の固有値は, ]となります(最小値 , 最大値 ). この時, 固有値「」 は, 「低周波 高周波」に対応することになります.
例えば, 固有値 に対応する固有ベクトルは, (連結グラフの場合) 全ての成分が のベクトルになります. これは, グラフ上で変化しない波に対応するため, 低周波の波として見なすことが出来ます.
4. GNNのローパスフィルタとしての役割
ようやく本題に入ります. GNNをグラフ信号処理の観点で考えます. ここでは, 文献 [1] のセクション3の議論に沿って説明します.
まず, 最初に導入したGNNに関して考えていきます.
ここで, を信号値とし, をフィルタと見なした時のフィルタの効果に関して考えます. 周波数ごとの解析を行うために, このフィルタを, 正規化グラフラプラシアン の観点で書き直す. すると,
と表せます. 信号値 を周波数ごとの波で分解したとすると, 各周波数成分の波に対して, の作用は, 固有値として各周波数(の2乗)を返します. つまり, 各周波数ごとの波に対して, このフィルタの効果は, が掛かることに対応します.
このフィルタを, Coraデータセット(論文引用ネットワーク)*を用いて表してみると,
(図) 文献 [1](与えられたネットワークデータに対して, グラフラプラシアンを求め, その固有値をそれぞれ求めてプロット)
となります. 低周波部分( に近い成分)が増幅され, 高周波部分( に近い成分)が減衰するローパスフィルタになっていることが分かります. ただし, 低周波成分は, 次数 が高くなると(層を増やしていく場合), 極端に信号値が増幅され過ぎてしまうことが分かります.
これに対して, 次に, 単純に正規化したフィルタ を考えてみます. この場合, となり, 同様に図で表すと,
先ほどとは異なり, 低周波成分のフィルタの最大値は となり, 層を増やしていっても極端な増幅はなくなることが分かります. 一方で, 高周波成分を見てみると, 偶数次では増幅され, 奇数次では負の値になってしまったりとローパスフィルタとしての機能は失われてしまっていることが分かります.
ここで, renormalization trick として導入したフィルタ に関して考えてみます. このフィルタは, 正規化グラフラプラシアンでなく, セルフループ付きの正規化グラフラプラシアン の観点で表すことができ,
となります. ここで, の固有値を考えると, 実は最大固有値が より小さくなるという事実があります(文献 [1] の付録参照). すると, このフィルタの効果は,
となり, 最大固有値が小さくなったことから, 高周波成分でのおかしな振る舞いが改善されて, より良いローパスフィルタが実現されていることが分かります.
補足
(*) Coraデータセット: 論文の引用ネットワークデータ, 各論文を適切なトピックに分類するタスク, 各論文が持つ特徴量は, 論文に含まれるBoW. 後で出てくる, Citeseer, Pubmedも同様な論文の引用ネットワークデータである.
よく使われるグラフデータにおける, タスクの精度と周波数成分の関係
文献 [2] では, よく使われるグラフデータにおいて, タスクの精度と周波数成分の関係が分析されています.
次の図は, 3つの論文引用ネットワークデータ, Cora, Citeseer, Pubmedに関する実験結果で, グラフ信号データを周波数成分ごとに分けた場合に, 低周波成分から順々に加えていくことを考えます(つまり, 横軸が の場合は, オリジナルのグラフ信号全体を表す). その割合ごとで、多層パーセプトロンモデル(MLP)を用いた場合のノード分類タスクの精度がプロットされています(また, データに追加するノイズの量ごとに, 3つの曲線がプロット)*.
(図) 文献 [2]
まず, 着目するのが, 3つのデータでどれも低周波成分が2割いかないくらいの時に, 精度が最も高くなり, それ以降減少していっているということです. これは, ノード分類タスクに対して, 有用な特徴が低周波成分に多く含まれており, 高周波成分には無駄な情報が多いようなケースになっているというように考えられます.
また, 点線は, グラフ信号データを全て使用した場合の各モデルの最高精度を表しています. 特に, 緑点線がMLPモデル, オレンジ点線がGNNモデルの精度を表しています**. 3つのデータで比較してみると, オレンジと緑の点線の差, つまりGNNの効果の大きさは, 曲線の減衰の大きさと関係していることが分かります.
つまり, 低周波成分の中にタスクに有用な特徴が多く含まれており, 高周波成分には無駄な特徴が多いようなデータに対して, GNNが効果的となっているように考えられます. それは, 上記でも議論したように, GNNのローパスフィルタとしての性質が大きく効いているからではないかというように考えることが出来ます.
補足
(*) ここで使用している3つのデータは, どれも論文引用ネットワークデータで, ノード同士のつながりの情報(隣接行列)と, 各ノードが持つ特徴量(BoW)の情報になっています. つながりの情報を使わなければ, 通常の特徴量を用いたノード分類問題となるため, つながり情報を使わない場合のMLPと, つながり情報を使うGNNの比較を考えることが出来ます.
(**) 上記では, オレンジ点線がGNNを意味すると記載しましたが, 引用した図上ではgfNNと書かれています. これは, GNNがローパスフィルタとしての役割が大きいという考察から, 論文で提案されている新たなモデルのことです.
下記が, 通常のGNN(GCN)と, SGC(Simple Graph Convolution, [1]), そしてgfNN[2]のモデルを比較した図になっています:
(図) 文献 [2]
gf(A) がフィルタ部分に相当します. SGCでは, フィルタを数回掛けて, それに対して重みを掛けるというシンプルな構成になっています. 一方で, gfNNでは, 活性化関数を挟み, さらに重みを掛けるという構成になっています.
様々なフィルタとノード分類問題の精度
最後に, 様々なフィルタを使用した場合の, それぞれの分類精度の変化を見ます. これに関して実験した結果が下図となります(SGCモデル[1]での実験結果).
(図) 文献 [1]
ここで, それぞれのフィルタは,
$$ \begin{align} &S_{adj} = D^{-1/2}AD^{-1/2},\quad \tilde{S}_{adj} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} \\ &S_{rw} = D^{-1}A,\quad \tilde{S}_{rw} = \tilde{D}^{-1}\tilde{A} \\ &S_{1-order} = I + D^{-1/2}AD^{-1/2} \end{align} $$
を表します. また, 横軸は, フィルタを何回作用させたか表し, つまりGNNの層の深さを意味します.
まず, 赤とオレンジのフィルタが最も精度が良くて, 層を深くしてことに対しても安定しているように見えます. これは, 先ほどの renormalization trick のフィルタに相当しており, ローパスフィルタとして良い性質を持っていることから, このような結果が得られていると考えられます.
一方で, 紫のフィルタは, 層を増やしていくと, 精度が悪化していきます. これは, 先ほど見たように, 次数を増やしていくと, 低周波成分が増幅され過ぎてしまうということが, 分類タスクに悪影響を与えていると考えられます.
また, 緑と青のフィルタは, 奇数層で精度が大幅に悪化していることが分かります. これも先ほど見たように, セルフループを入れていないことで, 奇数次の場合に高周波成分が負の値に増幅されることを見ました. それが分類タスクに悪影響を与えていると考えられます.
先ほどのフィルタの周波数成分ごとの作用の分析と照らし合わせて, ローパスフィルタとしての機能が分類タスクに大きく影響を与えているのではないかとここから考えることが出来ます.
5. まとめ
GNNは, 信号処理の観点でローパスフィルタの効果を持っている. そして, GNNが通常のMLPに比べて高精度となるようなデータでは, 低周波成分に(タスクに関して)有用な情報が多く含まれており, 高周波成分には無駄な情報が多いようなケースになっているように見受けれられる. また, 実際にタスクの精度を様々なフィルタで比較すると, 信号処理の観点での性質が反映された結果を与えている. GNNの最も大きな役割は, ローパスフィルタとしての性質の影響なのではないかと考えられる. 逆に言うと, ローパスフィルタとしての性質が生かせるようなデータが, GNNの効果が最も期待出来るものなのではないかと考えられる.
参考
[1] Simplifying Graph Convolutional Networks, F. Wu et al, ICML2019
[2] Revisiting Graph Neural Networks: All We Have is Low-Pass Filters, Hoang NT and T.Maehara, arXiv:1902.07153
[3] Semi-Supervised Classification with Graph Convolutional Networks, T.Kipf and M.Welling, ICLR2017
[4] Qiita グラフ畳み込み再考, @cotton-gluon, https://qiita.com/cotton-gluon/items/5c4e2f9c2c8a120863fa
[6] グラフ信号処理のすゝめ, 田中 雄一, https://www.jstage.jst.go.jp/article/essfr/8/1/8_15/_pdf