GNNとグラフ信号処理

概要

GNNは, 信号処理の観点でローパスフィルタの効果を持っている. そして, GNNが通常のMLPに比べて高精度となるようなデータでは, 低周波成分に(タスクに関して)有用な情報が多く含まれており, 高周波成分には無駄な情報が多いようなケースになっているように見受けれられる. また, 実際にタスクの精度を様々なフィルタで比較すると, 信号処理の観点での性質が反映された結果を与えている. GNNの最も大きな役割(通常のMLPとの違い)は, ローパスフィルタとしての性質の影響なのではないかと考えられる. 逆に言うと, ローパスフィルタとしての性質が生かせるようなデータが, GNNの効果が最も期待出来るものなのではないかと考えられる.

内容

  1. はじめに
  2. GNNとはどのようなものだったか?
  3. 信号処理のいくつか
  4. GNNのローパスフィルタとしての役割
  5. まとめ

1. はじめに

Graph Neural Network(GNN)は, グラフ分析の領域で近年State of the Artの精度を記録しているが, その表現能力に関しては, いくつか疑問が持たれています. 例えば, GNNのさきがけの論文であるT.Kipf & M.Welling(ICLR2017)では, 層を増やしても精度が向上しないことが報告されています:

f:id:masashi16:20190813131005p:plain

(図)T.Kipf & M.Welling(ICLR2017)における実験結果 (ノードのクラス分類タスク) Testの部分を見ると, 2, 3層での精度が最も良く, それ以降は減少していることが分かります.

これは, 通常のニューラルネットワークとは少し様相が異なるように思えます. これに対し, 最近2つの論文,

  • [1] Simplifying Graph Convolutional Networks, F. Wu et al, ICML2019
  • [2] Revisiting Graph Neural Networks: All We Have is Low-Pass Filters, Hoang NT and T.Maehara, arXiv:1902.07153

でグラフ信号処理の観点から, GNNの特性に関する議論がされています. 今回は, GNNが, グラフ信号処理の観点からどのように考えられるのかを簡単に見ていくことにします.

Note:

ここで記載している内容は, 厳密な議論ではないことに注意してください. 詳細な議論は論文の方の参照をお願いします. また, 記載に誤りがあるかもしれません. もし,間違いなどがあれば,ご指摘いただけると大変助かります.よろしくお願いします.

2. GNNとはどのようなものだったか?

前提

今回は「重みなし無向グラフ」を前提とします.

重みなし, つまり隣接行列は, エッジがある場合は  1, そうでない場合は  0とします. また, 有向グラフの場合, 隣接行列が対称行列とならないため, 理論的な考察が難しくなります.

「重みなし無向グラフ」以外のグラフは, 今回の記事で扱う範囲からは外れることに注意してください.

GNN概観

f:id:masashi16:20190813131012p:plain:w600

(参考) http://tkipf.github.io/misc/SlidesCambridge.pdf

基本的には, 自分自身も含めて, 隣接している特徴量を集約させるところが肝で, それ以降は, 通常のニューラルネットワークと同様に重みパラーメータ  W を掛けて, 活性化関数  \sigma を掛けるということにまとめられます. ここで,  \mathcal{N}(i) はノード  i の隣接ノードの集合を表しています.

一応, よく間違われる注意点をまとめておくと,

  • グラフの構造自体は, 更新では変わらない. 変わるのは, ノードの特徴量のみ.

  • 隣接ノードの情報を集約するが, この集約部分は学習パラメータを持たない. (ある一定の規則で集約させるのが基本, エッジ重みやノードが持つエッジ数など, ただし, GAT [P.Veličković et al, ICLR2018] など発展的なものは学習パラメータを加えている)

  • 集約した後は、単に重み行列を掛けて、活性化関数を掛けるだけで通常のニューラルネットワークと同じ

  • 自分自身の寄与(1項目)と, 隣接ノードの寄与に対する重みを分けるやり方もある (GraphSAGE [W.L.Hamilton et al, NIPS2017] など)

隣接特徴量の足し上げの行列表現

上記で, 隣接ノードに対する特徴量を集約するといったが, これはどのように行列演算として表されるでしょうか?

各ノードの特徴量を  h_i \in \mathbb{R}^{F} とし, これを下記のようにノード数  N 分だけ並べて行列として表します:

$$ \begin{align} H = \begin{pmatrix} h_{1}^{T} \\ h_{2}^{T} \\ \vdots \\ h_{N}^{T} \end{pmatrix} \end{align} $$

そこに, 隣接行列  A を掛けると,

$$ \begin{align} AH = \begin{pmatrix} a_{11}h_{1}^{T}+a_{12}h_{2}^{T}+\cdots+a_{1N}h_{N}^{T} \\ a_{21}h_{1}^{T}+a_{22}h_{2}^{T}+\cdots+a_{2N}h_{N}^{T} \\ \vdots \\ a_{N1}h_{1}^{T}+a_{N2}h_{2}^{T}+\cdots+a_{NN}h_{N}^{T} \end{pmatrix} \end{align} $$

隣接行列は, ノード  i j がエッジを持つ場合,  (i, j) 成分はノンゼロであるような行列なので,

$$ \begin{align} (AH)_{i,\,:} = \sum_{j \in \mathcal{N}(i)} a_{ij}\,h_{j}^{T}, \quad \text{where}~~\mathcal{N}(i)\,\text{: ノード $i$ の隣接ノードの集合} \end{align} $$

隣接しているノード  j に関する特徴量の和になります. つまり, 隣接行列  A を掛けることにより, 隣接ノードの特徴量の集約が実現出来ることが分かります.

GNNの行列表現

さらに, GNNでは, (1) 次数正規化, (2) セルフループの追加を行います.

(1) 次数正規化:

隣接ノードの特徴量を集約するのですが, その際に, エッジ  (i, j) に関して, それぞれの端ノードの次数で正規化します:

$$ \begin{align} A \to D^{-1/2}A D^{-1/2}, \quad \text{where}~~ D=\text{diag}(\sum_{i=1}^{N}a_{1i}, \sum_{i=1}^{N}a_{2i}, \cdots, \sum_{i=1}^{N}a_{Ni}) \end{align} $$

これは, 成分表示で表すと,

$$ \begin{align} (D^{-1/2}A D^{-1/2})_{ij} = \frac{1}{\sqrt{\sum_{u=1}^{N}a_{iu}\sum_{v=1}^{N}a_{iv}}}\,\,a_{ij} \end{align} $$

となります.

(2) セルフループの追加:

隣接ノードの特徴量を集約しますが, それよりも自分自身の特徴量の寄与の方が重要であると考えられます. そのため,

$$ \begin{align} A \to I + A \end{align} $$

隣接行列に, 単位行列を加えることで, セルフループの寄与を加えます.

(1)と(2)を行うことで, 結局,


\begin{align}
H^{(l+1)} = \sigma\Big{(} (I+D^{-1/2}AD^{-1/2})H^{(l)} W^{(l)} \Big{)}
\end{align}

となります. ただし, 実はこれではうまくいかないことが分かっており, GNNでは, renormalization trick (T.Kipf & M.Welling(ICLR2017)) と呼ばれる正規化をした下記が使われます:


\begin{align}
H^{(l+1)} = \sigma\Big{(} (\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2})H^{(l)} W^{(l)} \Big{)}
, \quad \text{where}~~\tilde{A}=I+A, \quad \tilde{D}=I+D
\end{align}

上記の2つを,  i 行成分だけ書いてみると,


\begin{align}
\big{\{}\,(I+D^{-1/2}AD^{-1/2})\,H \, \big{\}}_{i, \, :} =
h_{i}^{T} + \sum_{j \in \mathcal{N}(i)}\frac{1}{\sqrt{d_i d_j}}\,h_{j}^{T}
, \quad\text{where}~~d_{i} = \sum_{j=1}^{N} a_{ij}
\end{align}

これは, 隣接する部分は正規化して, 自分自身の寄与は最大値 1で集約させるような効果を示しています.

それに対し, renormalization trick の方は,

$$ \begin{align} \big{\{}\,(\tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2})\,H \, \big{\}}_{i,\, :} &= \big{\{} \, (\tilde{D}^{-1} + \tilde{D}^{-1/2}A\tilde{D}^{-1/2})\,H \, \big{\}}_{ij} \\ &= \frac{1}{(d_{i}+1)}\,h_{i}^{T} + \sum_{j \in \mathcal{N}(i)} \frac{1}{\sqrt{(d_{i} +1)(d_{j}+1)}}\,h_{j}^{T} \end{align} $$

これだけだと, この trick の効果は分かりづらいと思います. これをグラフ信号処理の観点で, どのような効果があるのかを見ていくのが, この記事の本題となります.

3. 信号処理のいくつか

GNNをグラフ信号処理の観点で考える前に, 信号処理の部分に関して, ここで必要となる知識だけ簡単にまとめておきます.

この辺りは, Qiitaの記事,

がとても分かりやすく, 今回の記事も多く参考にしています.

フィルタ

グラフ上の信号処理に入る前に, 通常の信号処理の基本的なところをおさらいします. ただし, ここでは時間変化する信号は扱いません.

まず, 信号処理の重要なキーワードの一つとして, 信号の特徴を抽出する「フィルタ」があります. 数学的に書くと, 信号  f(x) に対して, フィルタを通した後の信号  \tilde{f}(x) は,


\tilde{f}(x) = \int_{-\infty}^{\infty} w(s) f(x-s) ds

と表されます. ここで,  w(s) は窓関数と呼ばれ, フィルタの性質を表します. 例えば, 矩形窓と呼ばれる次のような窓関数を考えます,

f:id:masashi16:20190813131050p:plain:w480

これを作用させると, $$ \begin{align} \tilde{f}(x) = \frac{1}{2W} \int_{-W}^{W} f(s)\, ds \end{align} $$

となり, ある一定幅における平均を抽出するフィルタを表していることが分かります.

もう一つ代表的な例として, ガウシアンフィルタがあります.

f:id:masashi16:20190813131058p:plain:w430

このフィルタの特徴としては, スケール  \sigma に対して, 大きな幅で振動する緩やかな波の場合は, より強調された信号が抽出されます. 一方で, スケール  \sigma に比べて, 短い幅で激しく振動する波の場合, 平滑化された波が抽出されます. つまり, 低周波成分は強く, 高周波成分を弱める, 「ローパスフィルタ」の役割をしています.

f:id:masashi16:20190813131106j:plain:w500

(図) Qiita グラフ畳み込み再考, https://qiita.com/cotton-gluon/items/5c4e2f9c2c8a120863fa

フーリエ変換ラプラシアン

一般に, 信号  f(x) は,  \cos \sin の様々な周波数の波に分解することが出来ます(フーリエ変換).

f:id:masashi16:20190813131124p:plain:w450

例えば, 一つの  \cos 波に, ガウシアンフィルタを作用させてみると,

$$ \begin{align} \tilde{f}(x) &= \int_{-\infty}^{\infty} w(s) f(x-s)\,ds \\ &= \frac{1}{\sqrt{2\pi \text{σ}^{2}}} \int_{-\infty}^{\infty} \,e^{-s^{2}/2 \text{σ}^{2}}\, \cos [\omega (x-s)]\,ds \\ &= \cdots \\ &= e^{-\frac{\text{σ}^{2}}{2}\omega^{2}} \,\cos(\omega x) \end{align} $$ (この計算では, オイラーの公式ガウス積分の複素版を使っています)

ここから, ガウシアンフィルタは, 波の形は変えずに, 振幅を減衰する作用を与えていることが分かります. 特に, 周波数  \omega が大きいほど減衰の大きさも大きくなることが分かります.  \sin 波に関しても同様の結果が得られます.

また, このフィルタの作用は,

$$ \begin{align} \tilde{f}(x) = &= e^{-\frac{\text{σ}^{2}}{2}\omega^{2}} \,\cos(\omega x) \\ &=\exp\Big{[} -\frac{\text{σ}^{2}}{2} \underbrace{\Big{(} -\frac{d^{2}}{dx^{2}} \Big{)}}_{=\Delta} \Big{]}\, \cos(\omega x) \end{align} $$

として, 微分作用素, ラプラシアン  \Delta の観点で書き直すことが出来ます. つまり, ラプラシアンと信号の周波数が結びついていることが分かります.

そして, ここから,


\begin{align}
\Delta\, \cos(\omega x) = \omega^{2} \cos(\omega x), \quad
\Delta\, \sin(\omega x) = \omega^{2} \sin(\omega x)
\end{align}

ということも分かります. つまり, フーリエ変換の一つ一つの波の成分は, ラプラシアンに対する固有関数となっており, また, その固有値は周波数の2乗になっていることが分かります. フーリエ変換は, ラプラシアンという微分作用素の固有関数による展開として捉えることが出来ます.

グラフ上での信号処理

では, グラフ上の信号処理に関して考えていきます. グラフ構造上で, 各ノードが何か信号値を持っているような状況を想定します.

f:id:masashi16:20190813131213p:plain:w300

(参考)グラフ信号処理のすゝめ, https://www.jstage.jst.go.jp/article/essfr/8/1/8_15/_pdf (赤い線がグラフ構造, 青い線が各ノード上の信号値を表す)

このようにグラフ上での信号を考える場合, フーリエ変換はどのように表すことが出来るでしょうか?

先ほどの通常の信号処理で見てきたように, ラプラシアンを導入することによって, フーリエ変換を定義出来ると考えられます.

では, グラフ上での信号の微分はどのように導入出来るでしょうか? ここでは, 直感的な説明となりますが,

に沿って説明します.

次のようなグラフ信号を例として考えてみます:

f:id:masashi16:20190813131141p:plain:w350

ここで, 次のような接続行列(Incidence Matrix) K を考えます,

f:id:masashi16:20190813131231p:plain:w350

これは, ノード数  \times エッジ数の行列で, 各エッジがあるところに対応するノードがあるところに値を持ちます. この時, ノード番号が小さいものの値を  1, 大きいものの値を  -1とします. このような行列を導入すると,


\begin{align}
K^{T} f = \begin{pmatrix}
f_{1} - f_{2} \\
f_{2} - f_{3}  \\
f_{2} - f_{4} \\
f_{4} - f_{5}
\end{pmatrix}
= \begin{pmatrix}
-2 \\
2 \\
1 \\
1
\end{pmatrix}
\end{align}

となり, 各エッジごとの信号値の差分が得られます. つまり, この  K微分作用素として見なすことが出来そうです. すると, ラプラシアン  \Delta は, 2階微分作用素(勾配の発散)なので,

$$ \begin{align} \Delta = K K^{T} &= \begin{pmatrix} 1 & 0 & 0 & 0 \\ -1 & 1 & 1 & 0 \\ 0 & -1 & 0 & 0 \\ 0 & 0 & -1 & 1 \\ 0 & 0 & 0 & -1 \end{pmatrix} \begin{pmatrix} 1 & -1 & 0 & 0 & 0 \\ 0 & 1 & -1 & 0 & 0 \\ 0 & 1 & 0 & -1 & 0 \\ 0 & 0 & 0 & 1 & -1 \\ \end{pmatrix} \\ &= \begin{pmatrix} 1 & -1 & 0 & 0 & 0 \\ -1 & 3 & -1 & -1 & 0 \\ 0 & -1 & 1 & 0 & 0 \\ 0 & -1 & 0 & 2 & -1 \\ 0 & 0 & 0 & -1 & 1 \end{pmatrix} \end{align} $$

というように表されます. これは, グラフラプラシアンと呼ばれ, 先ほどの信号処理の観点から, この行列の固有値固有ベクトルを求めることで, グラフ上の各周波数成分の波が分かります.

このグラフラプラシアンは, 対角成分を見てみると, 各ノードのエッジ数の合計値が入っていることが分かります( [1, 3, 1, 2, 1]). また, 各行ごとに非対角成分を見てみると, エッジを持っている成分に  -1 が入っていることが分かります. つまり,


\Delta = KK^{T} = D - A, \quad
\text{where}~~ D=\text{diag}(\sum_{i=1}^{N}a_{1i}, \sum_{i=1}^{N}a_{2i}, \cdots, \sum_{i=1}^{N}a_{Ni})

という関係になっていることが分かります. また, 色々面白い性質を持っています. 特に, 正規化されたグラフラプラシアン,

 \Delta_{norm}=D^{-1/2}\Delta D^{-1/2}=I-D^{-1/2}AD^{-1/2}

固有値は,  [0, 2]となります(最小値  0, 最大値  2). この時, 固有値 0 \to 2」 は, 「低周波  \to 高周波」に対応することになります.

例えば, 固有値  0 に対応する固有ベクトルは, (連結グラフの場合) 全ての成分が  1 のベクトルになります. これは, グラフ上で変化しない波に対応するため, 低周波の波として見なすことが出来ます.

4. GNNのローパスフィルタとしての役割

ようやく本題に入ります. GNNをグラフ信号処理の観点で考えます. ここでは, 文献 [1] のセクション3の議論に沿って説明します.

まず, 最初に導入したGNNに関して考えていきます.


H^{(l+1)} = \sigma\Big{(} (I+D^{-1/2}AD^{-1/2})H^{(l)} W^{(l)} \Big{)}

ここで,  H を信号値とし,  (I+D^{-1/2}AD^{-1/2}) をフィルタと見なした時のフィルタの効果に関して考えます. 周波数ごとの解析を行うために, このフィルタを, 正規化グラフラプラシアン  \Delta_{norm} の観点で書き直す. すると,


I+D^{-1/2}AD^{-1/2} = 2I - \Delta_{norm}

と表せます. 信号値  H を周波数ごとの波で分解したとすると, 各周波数成分の波に対して,  \Delta_{norm} の作用は, 固有値として各周波数(の2乗)を返します. つまり, 各周波数ごとの波に対して, このフィルタの効果は,  2-\lambda,~(0\leq \lambda \leq 2) が掛かることに対応します.

このフィルタを, Coraデータセット(論文引用ネットワーク)*を用いて表してみると,

f:id:masashi16:20190813131254p:plain:w350

(図) 文献 [1](与えられたネットワークデータに対して, グラフラプラシアンを求め, その固有値をそれぞれ求めてプロット)

となります. 低周波部分( 0 に近い成分)が増幅され, 高周波部分( 2 に近い成分)が減衰するローパスフィルタになっていることが分かります. ただし, 低周波成分は, 次数  K が高くなると(層を増やしていく場合), 極端に信号値が増幅され過ぎてしまうことが分かります.

これに対して, 次に, 単純に正規化したフィルタ  D^{-1/2}AD^{-1/2} を考えてみます. この場合,  D^{-1/2}AD^{-1/2}=I-\Delta_{norm} となり, 同様に図で表すと,

f:id:masashi16:20190813131302p:plain:w350

先ほどとは異なり, 低周波成分のフィルタの最大値は  1 となり, 層を増やしていっても極端な増幅はなくなることが分かります. 一方で, 高周波成分を見てみると, 偶数次では増幅され, 奇数次では負の値になってしまったりとローパスフィルタとしての機能は失われてしまっていることが分かります.

ここで, renormalization trick として導入したフィルタ  (\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}) に関して考えてみます. このフィルタは, 正規化グラフラプラシアンでなく, セルフループ付きの正規化グラフラプラシアン  \tilde{\Delta}_{norm} = \tilde{D}^{-1/2} \Delta \tilde{D}^{-1/2} の観点で表すことができ,


(\tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2})
= I - \tilde{\Delta}_{norm}

となります. ここで,  \tilde{\Delta}_{norm}固有値を考えると, 実は最大固有値 2 より小さくなるという事実があります(文献 [1] の付録参照). すると, このフィルタの効果は,

f:id:masashi16:20190813131308p:plain:w350

となり, 最大固有値が小さくなったことから, 高周波成分でのおかしな振る舞いが改善されて, より良いローパスフィルタが実現されていることが分かります.

補足

(*) Coraデータセット: 論文の引用ネットワークデータ, 各論文を適切なトピックに分類するタスク, 各論文が持つ特徴量は, 論文に含まれるBoW. 後で出てくる, Citeseer, Pubmedも同様な論文の引用ネットワークデータである.

よく使われるグラフデータにおける, タスクの精度と周波数成分の関係

文献 [2] では, よく使われるグラフデータにおいて, タスクの精度と周波数成分の関係が分析されています.

次の図は, 3つの論文引用ネットワークデータ, Cora, Citeseer, Pubmedに関する実験結果で, グラフ信号データを周波数成分ごとに分けた場合に, 低周波成分から順々に加えていくことを考えます(つまり, 横軸が  1 の場合は, オリジナルのグラフ信号全体を表す). その割合ごとで、多層パーセプトロンモデル(MLP)を用いた場合のノード分類タスクの精度がプロットされています(また, データに追加するノイズの量ごとに, 3つの曲線がプロット)*.

f:id:masashi16:20190813131327p:plain:w800

(図) 文献 [2]

まず, 着目するのが, 3つのデータでどれも低周波成分が2割いかないくらいの時に, 精度が最も高くなり, それ以降減少していっているということです. これは, ノード分類タスクに対して, 有用な特徴が低周波成分に多く含まれており, 高周波成分には無駄な情報が多いようなケースになっているというように考えられます.

また, 点線は, グラフ信号データを全て使用した場合の各モデルの最高精度を表しています. 特に, 緑点線がMLPモデル, オレンジ点線がGNNモデルの精度を表しています**. 3つのデータで比較してみると, オレンジと緑の点線の差, つまりGNNの効果の大きさは, 曲線の減衰の大きさと関係していることが分かります.

つまり, 低周波成分の中にタスクに有用な特徴が多く含まれており, 高周波成分には無駄な特徴が多いようなデータに対して, GNNが効果的となっているように考えられます. それは, 上記でも議論したように, GNNのローパスフィルタとしての性質が大きく効いているからではないかというように考えることが出来ます.

補足

(*) ここで使用している3つのデータは, どれも論文引用ネットワークデータで, ノード同士のつながりの情報(隣接行列)と, 各ノードが持つ特徴量(BoW)の情報になっています. つながりの情報を使わなければ, 通常の特徴量を用いたノード分類問題となるため, つながり情報を使わない場合のMLPと, つながり情報を使うGNNの比較を考えることが出来ます.

(**) 上記では, オレンジ点線がGNNを意味すると記載しましたが, 引用した図上ではgfNNと書かれています. これは, GNNがローパスフィルタとしての役割が大きいという考察から, 論文で提案されている新たなモデルのことです.

下記が, 通常のGNN(GCN)と, SGC(Simple Graph Convolution, [1]), そしてgfNN[2]のモデルを比較した図になっています:

f:id:masashi16:20190813131354p:plain:w300

(図) 文献 [2]

gf(A) がフィルタ部分に相当します. SGCでは, フィルタを数回掛けて, それに対して重みを掛けるというシンプルな構成になっています. 一方で, gfNNでは, 活性化関数を挟み, さらに重みを掛けるという構成になっています.

様々なフィルタとノード分類問題の精度

最後に, 様々なフィルタを使用した場合の, それぞれの分類精度の変化を見ます. これに関して実験した結果が下図となります(SGCモデル[1]での実験結果).

f:id:masashi16:20190813131414p:plain:w800 (図) 文献 [1]

ここで, それぞれのフィルタは,

$$ \begin{align} &S_{adj} = D^{-1/2}AD^{-1/2},\quad \tilde{S}_{adj} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2} \\ &S_{rw} = D^{-1}A,\quad \tilde{S}_{rw} = \tilde{D}^{-1}\tilde{A} \\ &S_{1-order} = I + D^{-1/2}AD^{-1/2} \end{align} $$

を表します. また, 横軸は, フィルタを何回作用させたか表し, つまりGNNの層の深さを意味します.

まず, 赤とオレンジのフィルタが最も精度が良くて, 層を深くしてことに対しても安定しているように見えます. これは, 先ほどの renormalization trick のフィルタに相当しており, ローパスフィルタとして良い性質を持っていることから, このような結果が得られていると考えられます.

一方で, 紫のフィルタは, 層を増やしていくと, 精度が悪化していきます. これは, 先ほど見たように, 次数を増やしていくと, 低周波成分が増幅され過ぎてしまうということが, 分類タスクに悪影響を与えていると考えられます.

また, 緑と青のフィルタは, 奇数層で精度が大幅に悪化していることが分かります. これも先ほど見たように, セルフループを入れていないことで, 奇数次の場合に高周波成分が負の値に増幅されることを見ました. それが分類タスクに悪影響を与えていると考えられます.

先ほどのフィルタの周波数成分ごとの作用の分析と照らし合わせて, ローパスフィルタとしての機能が分類タスクに大きく影響を与えているのではないかとここから考えることが出来ます.

5. まとめ

GNNは, 信号処理の観点でローパスフィルタの効果を持っている. そして, GNNが通常のMLPに比べて高精度となるようなデータでは, 低周波成分に(タスクに関して)有用な情報が多く含まれており, 高周波成分には無駄な情報が多いようなケースになっているように見受けれられる. また, 実際にタスクの精度を様々なフィルタで比較すると, 信号処理の観点での性質が反映された結果を与えている. GNNの最も大きな役割は, ローパスフィルタとしての性質の影響なのではないかと考えられる. 逆に言うと, ローパスフィルタとしての性質が生かせるようなデータが, GNNの効果が最も期待出来るものなのではないかと考えられる.

参考

論文紹介: Deep Graph Infomax

Deep Graph Infomax (https://arxiv.org/abs/1809.10341) について今回紹介します. 個人的に,とても注目している技術なので,まとめてみました. もし,間違いなどがあれば,ご指摘いただけると大変助かります.よろしくお願いします.

この論文は,ICLR 2019のポスターに採択されており,Graph Attention Networks (https://arxiv.org/abs/1710.10903) や GraphSAGE (https://arxiv.org/abs/1706.02216) といったGraph Neural Networksの代表的な論文の著者達による研究になります.

概要

最近提案された"Deep Infomax"と呼ばれる教師なしの表現学習手法を,グラフデータに適用したもの.

何がすごいのか?

得られた表現 (特徴量) をロジスティック回帰で学習させたものが,教師あり学習のGCNと同等以上の精度を達成していること.


[論文中のTable2に記載の実験結果 (ノードのクラス分類タスク)] f:id:masashi16:20190519215704p:plain

モチベーション

まず,この論文でやりたいことは,「グラフデータにおける教師なしでのノードの表現学習」,つまりノードエンベディングを求めることである.

代表的な手法としては,Deep Walk (https://arxiv.org/abs/1403.6652) が挙げられる.Deep Walk は,グラフデータ上でランダムウォークを行うことで,グラフデータを系列データの集合に置き換えて,そこにword2vecを適用することで,ノードごとの表現を学習させる.


[Deep Walkのイメージ図]

f:id:masashi16:20190519215843p:plain

  • メリット:

    • つながりとしての近さを,表現空間にうまく反映させられる
    • word2vecの利用により,学習がスケーラブル


他の手法としては,SDNE (https://www.kdd.org/kdd2016/subtopic/view/structural-deep-network-embedding) などのように,隣接行列を再構成させるように表現を学習させるオートエンコーダもあるが,計算コストが高いことがネックである.

また,GCN (Graph Convolutional Networks; https://arxiv.org/abs/1609.02907)で教師なし表現学習を学習させる方法が,GraphSAGE (https://arxiv.org/abs/1706.02216) で提案されているが,ここでもランダムウォークで得られる共起に基づく学習となっている.

今回のモチベーションとしては,ランダムウォークに依存しない表現学習の手法を提案することにある.

Deep Infomax

まず,基になっている手法であるDeep Infomax (https://arxiv.org/abs/1808.06670) について述べる.

相互情報量基準

グラフデータに限らず,従来の表現学習手法の代表的なものとしては,オートエンコーダが挙げられる. しかし,オートエンコーダは,入力全体の再構成誤差を最小にするという基準のため,例えば画像データで言えば,本質的には関係ない背景などの細かな部分も学習に影響を与えてしまうという問題がある.

Deep Infomaxでは,再構成誤差の代わりに相互情報量を基準として用いる.相互情報量を利用した表現学習の例としては,infoGAN や \beta-VAE がある.入力と潜在表現の間の相互情報量を高くするような制約をつけて学習することで,disentangled な表現の獲得に寄与することが分かっている.

[例: infoGANでMNISTを学習させた結果 (https://arxiv.org/abs/1606.03657)] f:id:masashi16:20190519215927p:plain

生成器で用いるパラメータにおいて,c_1 は数字のタイプ,c_2 は文字の回転,c_3 は文字の太さに対応するように学習された結果が得られる.

グローバルな情報の取り込み

ただ,再構成誤差を相互情報量に変えただけでは,上記で述べたオートエンコーダの問題は本質的には解決されていない.例えば,下図ではCNNのオートエンコーダを用いて猫の表現学習を行っており,その際に猫に関連する情報の他にも,関係しない情報も多く抽出してしまう:

f:id:masashi16:20190519220004p:plain (https://aisc.a-i.science/static/slides/20190411_KaranGrewal.pdf)

では,出来るだけ上図で言うところの relevant-information を多く取り出すにはどうすればよいだろうか? そもそもCNNではローカルな情報しか見ていないため,relevant-informationとは何であるのかということが分からない.そこで,何らかのグローバルな情報をローカルなパッチに与えてやることで,関連する情報とは何であるかをローカル部分にも教えてやるというようなことを考える. ここでは,CNNの各パッチの情報を1つにまとめた "global summary feature" というものを考えて,それと各パッチ間の相互情報量を最大にするように学習させるという戦略を取る.

少し意味が分かりにくいと思うので,(おそらく) やりたいことの気持ちみたいなものを書いてみる.この学習は,self-supervisionと呼ばれるラベルなしでどのように学習タスクを設定するかという問題である.そのような学習の方法として,例えば,オートエンコーダ,また文章表現学習において,Skip-thought などがある.


f:id:masashi16:20190519215850p:plain

Deep infomaxで学習しているところの気持ちを書くと上図のようなものになると考えている(論文には具体的にこのようなことは記載されていないので,ミスリーディングを促しているかもしれないのでご留意ください). ちなみに,この辺りに関しては,Microsoftのブログ (https://www.microsoft.com/en-us/research/blog/deep-infomax-learning-good-representations-through-mutual-information-maximization/)にも記載があります.

MINE (Mutual Information Neural Estimator; https://arxiv.org/abs/1801.04062)

上記で,相互情報量を考えるといったが,具体的にどのように計算すればいいだろうか?最近,相互情報量のスケーラブルな計算手法としてMINEが提案されている.

$$ \begin{align} MI(X, Y) &= H(X) - H(X | Y) \\ &= D_{KL} (P_{XY} || P_X P_Y ) \\ &= \sup_{T}~ \mathbb{E}_{x,y\sim P_{XY}} [T(x,y)] -\log \mathbb{E}_{\substack{x\sim P_X \\ y\sim P_Y}} [e^{T(x,y)}] \end{align} $$

ここで,MI相互情報量 Hエントロピーを表す.1行目と2行目は相互情報量の定義から導かれる.3行目が,KL-divergence の Donsker-Varadhan表現と呼ばれる表現である (導出は補足参照). MINEでは,このDonsker-Varadhan表現を用いることで,次のようにニューラルネットワーク相互情報量を推定するフレームワークに落とす:


f:id:masashi16:20190525161134p:plain

これは,後述するが,対応があるポジティブサンプル {(x^{(i)}, y^{(i)})} と対応がないネガティブサンプル {(x^{(j)}, \tilde{y}^{(j)})} を区別するようなモデルを学習させることに対応している.

補足

Deep Infomaxの論文では,この他にも学習させる分布に制約をつけたり,確率分布間の距離をKL-divergence以外のものに変えるなど様々な工夫が行われている.確率分布間の距離の変更に関しては後述する.

Deep Graph Infomax

Deep Graph Infomaxは,上記で説明してきたDeep Infomaxをグラフデータに適用したものである.

問題設定

N 個のノードがあり,隣接行列 A\in \mathbb{R}^{N \times N} が与えられているとする. 隣接行列は,ノード ij の間にエッジがあるなら要素 A_{ij} \neq 0 となるような行列である. さらに,ノードごとの属性情報行列 X\in \mathbb{R}^{N\times F} が与えられているとする. このような状況で,各ノードごとの表現ベクトル  \{ h_n \in \mathbb{R}^{d} \}_{n=1}^{N}~~(d \ll N) を求めることを考える.

Deep Infomaxのグラフデータへの適用

1. まず,(X, A) を何らかの方法でエンコードする: $$ \begin{align} (X, A) ~~ \overset{E_\psi}{\Longrightarrow} ~~\{h_n \}_{n=1}^{N} \end{align} $$

2. 1でエンコードした \{h_n \}_{i=1}^{N} を用いて,全てのノード表現をサマライズした global sumarry feature として,s を構築する

3. 一方で,(X, A) とは異なるデータ (\tilde{X}, \tilde{A}) を用意し,上記と同様のエンコードを行い,\{\tilde{h}_m \}_{m=1}^{M} を求める: $$ \begin{align} (\tilde{X}, \tilde{A}) ~~ \overset{E_\psi}{\Longrightarrow} ~~\{\tilde{h}_m \}_{m=1}^{M} \end{align} $$

4. 対応があるポジティブペアの集合 \{(h_1, s), (h_2, s), \cdots, (h_N, s) \} と, 対応がない ネガティブペアの集合 \{(\tilde{h}_1, s), (\tilde{h}_2, s), \cdots, (\tilde{h}_M, s)\} に対して, 先ほどのMINEの目的関数を最大にするように,関数 T_\theta を求めることで,各ノード表現とグローバル表現の間の相互情報量の推定を行う

5. さらに,相互情報量を最大にするように,エンコーダ部分 E_\psi を学習させるが,このエンコーダ部分 E_\psi とMINEの T_\theta 部分のニューラルネットの層を共有させることで,相互情報量の推定と最大化を一気に行うように学習させる

以上のフレームワークにより,各ノードの表現と,グローバル表現の間の相互情報量を最大にするように学習させることで,ノード表現 \{h_n \}_{i=1}^{N} を求める.

この論文での具体的な設定

エンコーダ部分,サマライズの方法,異なるグラフデータの選択方法は様々な選択肢が考えられる.この論文では,具体的に下記のような設定をおく:


  • エンコーダ: GCN (Graph Convolutional Networks) を利用

    GCNは,各ノード n ごとに隣接ノード \mathcal{N}(n)  と自分自身の特徴量を重みつきで足し上げる以下のような更新ルールでニューラルネットワークを構築する:

    
\begin{align}
h_{n}^{(l+1)} = \sigma \Big{\{} \Big{(}  \frac{1}{c_{n}} h_{n}^{(l)} + \sum_{m \in \mathcal{N}(n)} \frac{1}{c_{nm}} h_{m}^{(l)} \Big{)} \, W^{(l)} \Big{\}}
\end{align}

    ここで, c_{n},  c_{nm} はノードに依存した定数, W^{(l)}ニューラルネットワークで学習させるパラメータ行列である.


  • サマライズの方法: 

    
\begin{align}
s = \sigma \Big{(} \frac{1}{N} \sum_{n=1}^{N} h_n \Big{)},~~~~\sigma: \text{シグモイド関数}
\end{align}

    単純に,全てのノード表現の平均を取る.直感的に,relevant-informationは全てのノードに共通して含まれるようなものだと考えられるので,平均を取ることでrelevant-informationが高められ,逆にirelevant-informationは薄まってくれると期待できる.論文では,いくつか試した中で,平均を取るものが最も良かったと報告されている.


  • 異なるグラフデータ (\tilde{X}, \tilde{A}) の選択:

    今回は,\tilde{A} = A としオリジナルデータと同一に取り, \tilde{X} は行をランダムシャッフルしたものを選ぶ.この選択は任意性があるが,Appendix C で,この選択による影響の差は大きくないということが確認されている.

目的関数の変更

相互情報量は,KL-divergenceとして表され,MINEでは Donsker-Varadhan 表現で書き換えることでニューラルネットで推定問題を解く形に帰着させていた.ここで,KL-divergence を Jensen-Shannon divergence (JS-divergence) に置き換えることを考える. JS-divergence は,KL-divergence を対称にしたものである.


\begin{align}
D_{JS}(P_{XY}||P_X P_Y) = \frac12 \Big{\{} D_{KL}\Big{(} P_{XY}|| \frac{P_{XY}+P_X P_Y}{2} \Big{)} +D_{KL}\Big{(} P_X P_Y \,|| \, \frac{P_{XY} + P_X P_Y}{2} \Big{)} \Big{\}}
\end{align}

例えば,下図は A の人工的に生成した混合ガウス分布 P を,シングルガウス分布 Q で推定したもので,B, C, D の3つの別々な推定方法での比較した結果となる:

f:id:masashi16:20190519215930p:plain (https://arxiv.org/abs/1511.05101)

B は \text{arg}\min_{Q} D_{KL}(P||Q),D は PQ を逆にした \text{arg}\min_{Q} D_{KL}(Q||P) で推定した結果となる.C が,\text{arg}\min_{Q} D_{JS}(P||Q) で推定した結果に対応しており,KL-divergence の推定よりも良い推定結果を与えていることが分かる.

例えば,GANを考えてみると,生成器は理想的に JS-divergence を最小にすることにより学習される. また,GANでは JS-divergence を,より広いクラスである f-divergence や,Wasserstein距離といったものに置き換えることにより,より安定した学習になることが分かっている.この辺りは,例えば,From GAN to WGAN (https://lilianweng.github.io/lil-log/2017/08/20/from-GAN-to-WGAN.html) が詳しい.

ここでは,KL-divergence を JS-divergence に置き換えることを考える:

$$ \begin{align} MI(X, Y) &= D_{KL}(P_{X, Y} \, || \, P_X P_Y) \\ & \to D_{JS}(P_{X, Y} \, || \, P_X P_Y) \geq \sup_{D} (-\mathcal{L}_{BCE}(D)) + \log 4 \end{align} $$

ここで,\mathcal{L}_{BCE} は,Binary Cross Entropy誤差関数,D はその識別器を表す関数である.

上記の最後の不等式部分は,例えば,GANで最適な識別器 D=D^{*} を学習させた際の目的関数が, V(D^{*}) = - \mathcal{L}_{BCE}(D^*) = 2D_{JS} - \log 4 となり,識別器の学習と,JS-divergence の推定の問題はつながっていることからも分かる (詳細は補足参照).

これにより結局,ローカルなノード表現 {h_n } とグローバルな表現 s の間の相互情報量を最大化するという問題は,\{(h_n, s)\}\{(\tilde{h}_m, s)\} の間の識別器を学習する問題に帰着される.

Deep Graph Infomaxまとめ

今まで見てきたことは,論文の3.4 OVERVIEW OF DGIにあるように下記でまとめられる: f:id:masashi16:20190519215903p:plain

この論文では,具体的な設定を置いているが,フレームワークとしては,上記の 別データを作る関数 \mathcal{C}, エンコードする関数 \mathcal{E}, サマライズする関数 \mathcal{R}, 識別器 \mathcal{D} は柔軟に設定してやることが出来る.

f:id:masashi16:20190519215854p:plain

実験

ノードのクラス分類タスク

f:id:masashi16:20190519215958p:plain

ここでは,

  • Transduvtive: 1つのグラフデータだけ与えられていて,その内のいくつかのノードラベルがますくされている半教師ありの設定

  • Inductive: 訓練とテストで,異なるグラフデータが与えられる設定

の2つの設定での実験を考える.

表現学習で得られた特徴量を,ロジスティック回帰で学習を行う.結果としては,冒頭にも載せたような結果となる (分類精度):

f:id:masashi16:20190519215707p:plain [論文中のTable2に記載の実験結果 (ノードのクラス分類タスク)]

ここで特筆すべきが,ラベル情報 Y も使っている教師あり学習のGCNと同等以上の精度を記録していることである.

2次元での可視化

Coraデータセットで,表現学習を行い,t-SNEで2次元に次元圧縮して可視化したのが下図である:

f:id:masashi16:20190519215953p:plain

見て分かるようにうまく分離されており,また従来のシルエットスコアよりも高い値になっているとのことである.

まとめ

最近提案されたDeep Infomaxと呼ばれる新たな表現学習手法をグラフデータに適用することで,従来のランダムウォークに依存しないノード表現学習を可能にした. そして,ノード分類のタスクにおいて,教師あり学習のGCNと同等以上の精度を記録.



補足

KL-divergenceのDonsker-Varadhan表現

$$ \begin{align} D_{KL} (P || Q ) = \sup_{T}~ \mathbb{E}_{p\sim P} [T(p)] -\log \mathbb{E}_{\substack{q\sim Q}} [e^{T(q)}] \end{align} $$

ここでは離散的な場合,T(i) = t_iで示す.右辺が最大となる T を求めると, $$ \begin{align} &\frac{\partial}{\partial t_j} \Big{(} \sum_{i} p_i t_i -\log \sum_i q_i e^{t_i} \Big{)} =0 \\ &\Rightarrow ~~ p_j - \frac{q_j e^{t_j}}{\sum_i q_i e^{t_i}} = 0 \\ &\Rightarrow ~~ t_j = \log \frac{p_j}{q_j} + \underbrace{\log \sum_i q_i e^{t_i}}_{=\alpha\,\text{とおく}} \end{align} $$ これを,実際に右辺に代入すると, $$ \begin{align} \sum_{i} p_i t_i -\log \sum_i q_i e^{t_i} &= \sum_i p_i \Big{(} \log \frac{p_i}{q_i} + \alpha \Big{)} -\log \sum_i q_i \exp \Big{(} \log \frac{p_i}{q_i} + \alpha \Big{)} \\ &= \sum_i p_i \log \frac{p_i}{q_i} = D_{KL}(P || Q) \end{align} $$

JS-divergenceの下限表現

JS-divergenceを含む f-divergence の下限表現は,f-GAN (https://arxiv.org/abs/1606.00709) で求められている.

f-divergenceとは, $$ \begin{align} D_f(P || Q) = \int q(x) \,f\Big{(} \frac{p(x)}{q(x)} \Big{)}\,dx \end{align} $$ ここで,f は凸関数で,f(1)=0 を満たすものである. 例えば,f(u) = u \log u, ~u=p(x)/q(x) とすると,KL-divergenceとなる.

f の凸性から,Fenchel conjugateを考える(例えば,PRMLの10.5が詳しい.凸関数は接線によって下から抑えられるという性質を用いて導かれる双対表現). $$ \begin{align} f^{*}(t) =\sup_{u \in \text{dom}_f} (ut -f(u)) \\ f(u) = \sup_{t \in \text{dom}_{f^{*}}}(tu - f^{*} (t)) \end{align} $$

上記の f-divergenceを,Fenchel conjugateで書き直すと, $$ \begin{align} D_f (P||Q) &= \int q(x) \sup_{t \in \text{dom}_{f^{*}}} \Big{(} t \frac{p(x)}{q(x)} - f^{*} (t) \Big{)} \\ &\geq \sup_T \Big{[} \int p(x) T(x) dx -q(x) f^{*} (T(x))dx \Big{]} ~~~ (\because \text{Jensen不等式+$\alpha$})\\ &= \sup_T ~ \mathbb{E}_{x\sim P}[T(x)] -\mathbb{E}_{x\sim Q} [f^{*} (T(x))] \end{align} $$

これが,一般的な f-divergenceの下限表現となる.特に,今回のJS-divergenceの場合の表現を求めてみると, $$ \begin{align} D_{JS}(P_{XY}||P_X P_Y) \geq \sup_S ~ \mathbb{E}_{P_{XY}}[S(x,y)] - \mathbb{E}_{P_X P_Y}[D_{JS}^{*}(S(x,y))] \end{align} $$ JS-divergence のFenchel conjugate D_{JS}^{*} は, $$ \begin{align} D_{JS}^{*}(S) = - \log (2-e^{S}) \end{align} $$ となるが,ここで関数 S の定義域は,S \lt \log 2 となる.そこで,定義域が \mathbb{R} 全体となるように,以下のように S から T に変換を行う: $$ \begin{align} S(x,y) = \log 2 -\log (1+e^{-T(x,y)}) \end{align} $$ この関数の取り方には任意性はあるが,ここでは f-GAN の論文にあるようにGANの設定とコンシステントになるように選んでいる.

すると, $$ \begin{align} D_{JS}(P_{XY}||P_X P_Y) &\geq \sup_T ~ \mathbb{E}_{P_{XY}} [\log 2 -\log(1+e^{-T(x,y)})] - \mathbb{E}_{P_X P_Y} [-\log (2 -\frac{2}{1+e^{-T(x,y)}})] \\ &= \sup_T ~ \mathbb{E}_{P_{XY}} [-sp(-T(x,y))] - \mathbb{E}_{P_X P_Y} [sp(T(x,y))] +\log 4 \\ &= \sup_T ~ (-\mathcal{L}_{BCE}(T)) + \log4 \end{align} $$ ここで,sp(x) は,softplus関数である.

参考