論文紹介: Deep Graph Infomax

Deep Graph Infomax (https://arxiv.org/abs/1809.10341) について今回紹介します. 個人的に,とても注目している技術なので,まとめてみました. もし,間違いなどがあれば,ご指摘いただけると大変助かります.よろしくお願いします.

この論文は,ICLR 2019のポスターに採択されており,Graph Attention Networks (https://arxiv.org/abs/1710.10903) や GraphSAGE (https://arxiv.org/abs/1706.02216) といったGraph Neural Networksの代表的な論文の著者達による研究になります.

概要

最近提案された"Deep Infomax"と呼ばれる教師なしの表現学習手法を,グラフデータに適用したもの.

何がすごいのか?

得られた表現 (特徴量) をロジスティック回帰で学習させたものが,教師あり学習のGCNと同等以上の精度を達成していること.


[論文中のTable2に記載の実験結果 (ノードのクラス分類タスク)] f:id:masashi16:20190519215704p:plain

モチベーション

まず,この論文でやりたいことは,「グラフデータにおける教師なしでのノードの表現学習」,つまりノードエンベディングを求めることである.

代表的な手法としては,Deep Walk (https://arxiv.org/abs/1403.6652) が挙げられる.Deep Walk は,グラフデータ上でランダムウォークを行うことで,グラフデータを系列データの集合に置き換えて,そこにword2vecを適用することで,ノードごとの表現を学習させる.


[Deep Walkのイメージ図]

f:id:masashi16:20190519215843p:plain

  • メリット:

    • つながりとしての近さを,表現空間にうまく反映させられる
    • word2vecの利用により,学習がスケーラブル


他の手法としては,SDNE (https://www.kdd.org/kdd2016/subtopic/view/structural-deep-network-embedding) などのように,隣接行列を再構成させるように表現を学習させるオートエンコーダもあるが,計算コストが高いことがネックである.

また,GCN (Graph Convolutional Networks; https://arxiv.org/abs/1609.02907)で教師なし表現学習を学習させる方法が,GraphSAGE (https://arxiv.org/abs/1706.02216) で提案されているが,ここでもランダムウォークで得られる共起に基づく学習となっている.

今回のモチベーションとしては,ランダムウォークに依存しない表現学習の手法を提案することにある.

Deep Infomax

まず,基になっている手法であるDeep Infomax (https://arxiv.org/abs/1808.06670) について述べる.

相互情報量基準

グラフデータに限らず,従来の表現学習手法の代表的なものとしては,オートエンコーダが挙げられる. しかし,オートエンコーダは,入力全体の再構成誤差を最小にするという基準のため,例えば画像データで言えば,本質的には関係ない背景などの細かな部分も学習に影響を与えてしまうという問題がある.

Deep Infomaxでは,再構成誤差の代わりに相互情報量を基準として用いる.相互情報量を利用した表現学習の例としては,infoGAN や \beta-VAE がある.入力と潜在表現の間の相互情報量を高くするような制約をつけて学習することで,disentangled な表現の獲得に寄与することが分かっている.

[例: infoGANでMNISTを学習させた結果 (https://arxiv.org/abs/1606.03657)] f:id:masashi16:20190519215927p:plain

生成器で用いるパラメータにおいて,c_1 は数字のタイプ,c_2 は文字の回転,c_3 は文字の太さに対応するように学習された結果が得られる.

グローバルな情報の取り込み

ただ,再構成誤差を相互情報量に変えただけでは,上記で述べたオートエンコーダの問題は本質的には解決されていない.例えば,下図ではCNNのオートエンコーダを用いて猫の表現学習を行っており,その際に猫に関連する情報の他にも,関係しない情報も多く抽出してしまう:

f:id:masashi16:20190519220004p:plain (https://aisc.a-i.science/static/slides/20190411_KaranGrewal.pdf)

では,出来るだけ上図で言うところの relevant-information を多く取り出すにはどうすればよいだろうか? そもそもCNNではローカルな情報しか見ていないため,relevant-informationとは何であるのかということが分からない.そこで,何らかのグローバルな情報をローカルなパッチに与えてやることで,関連する情報とは何であるかをローカル部分にも教えてやるというようなことを考える. ここでは,CNNの各パッチの情報を1つにまとめた "global summary feature" というものを考えて,それと各パッチ間の相互情報量を最大にするように学習させるという戦略を取る.

少し意味が分かりにくいと思うので,(おそらく) やりたいことの気持ちみたいなものを書いてみる.この学習は,self-supervisionと呼ばれるラベルなしでどのように学習タスクを設定するかという問題である.そのような学習の方法として,例えば,オートエンコーダ,また文章表現学習において,Skip-thought などがある.


f:id:masashi16:20190519215850p:plain

Deep infomaxで学習しているところの気持ちを書くと上図のようなものになると考えている(論文には具体的にこのようなことは記載されていないので,ミスリーディングを促しているかもしれないのでご留意ください). ちなみに,この辺りに関しては,Microsoftのブログ (https://www.microsoft.com/en-us/research/blog/deep-infomax-learning-good-representations-through-mutual-information-maximization/)にも記載があります.

MINE (Mutual Information Neural Estimator; https://arxiv.org/abs/1801.04062)

上記で,相互情報量を考えるといったが,具体的にどのように計算すればいいだろうか?最近,相互情報量のスケーラブルな計算手法としてMINEが提案されている.

$$ \begin{align} MI(X, Y) &= H(X) - H(X | Y) \\ &= D_{KL} (P_{XY} || P_X P_Y ) \\ &= \sup_{T}~ \mathbb{E}_{x,y\sim P_{XY}} [T(x,y)] -\log \mathbb{E}_{\substack{x\sim P_X \\ y\sim P_Y}} [e^{T(x,y)}] \end{align} $$

ここで,MI相互情報量 Hエントロピーを表す.1行目と2行目は相互情報量の定義から導かれる.3行目が,KL-divergence の Donsker-Varadhan表現と呼ばれる表現である (導出は補足参照). MINEでは,このDonsker-Varadhan表現を用いることで,次のようにニューラルネットワーク相互情報量を推定するフレームワークに落とす:


f:id:masashi16:20190525161134p:plain

これは,後述するが,対応があるポジティブサンプル {(x^{(i)}, y^{(i)})} と対応がないネガティブサンプル {(x^{(j)}, \tilde{y}^{(j)})} を区別するようなモデルを学習させることに対応している.

補足

Deep Infomaxの論文では,この他にも学習させる分布に制約をつけたり,確率分布間の距離をKL-divergence以外のものに変えるなど様々な工夫が行われている.確率分布間の距離の変更に関しては後述する.

Deep Graph Infomax

Deep Graph Infomaxは,上記で説明してきたDeep Infomaxをグラフデータに適用したものである.

問題設定

N 個のノードがあり,隣接行列 A\in \mathbb{R}^{N \times N} が与えられているとする. 隣接行列は,ノード ij の間にエッジがあるなら要素 A_{ij} \neq 0 となるような行列である. さらに,ノードごとの属性情報行列 X\in \mathbb{R}^{N\times F} が与えられているとする. このような状況で,各ノードごとの表現ベクトル  \{ h_n \in \mathbb{R}^{d} \}_{n=1}^{N}~~(d \ll N) を求めることを考える.

Deep Infomaxのグラフデータへの適用

1. まず,(X, A) を何らかの方法でエンコードする: $$ \begin{align} (X, A) ~~ \overset{E_\psi}{\Longrightarrow} ~~\{h_n \}_{n=1}^{N} \end{align} $$

2. 1でエンコードした \{h_n \}_{i=1}^{N} を用いて,全てのノード表現をサマライズした global sumarry feature として,s を構築する

3. 一方で,(X, A) とは異なるデータ (\tilde{X}, \tilde{A}) を用意し,上記と同様のエンコードを行い,\{\tilde{h}_m \}_{m=1}^{M} を求める: $$ \begin{align} (\tilde{X}, \tilde{A}) ~~ \overset{E_\psi}{\Longrightarrow} ~~\{\tilde{h}_m \}_{m=1}^{M} \end{align} $$

4. 対応があるポジティブペアの集合 \{(h_1, s), (h_2, s), \cdots, (h_N, s) \} と, 対応がない ネガティブペアの集合 \{(\tilde{h}_1, s), (\tilde{h}_2, s), \cdots, (\tilde{h}_M, s)\} に対して, 先ほどのMINEの目的関数を最大にするように,関数 T_\theta を求めることで,各ノード表現とグローバル表現の間の相互情報量の推定を行う

5. さらに,相互情報量を最大にするように,エンコーダ部分 E_\psi を学習させるが,このエンコーダ部分 E_\psi とMINEの T_\theta 部分のニューラルネットの層を共有させることで,相互情報量の推定と最大化を一気に行うように学習させる

以上のフレームワークにより,各ノードの表現と,グローバル表現の間の相互情報量を最大にするように学習させることで,ノード表現 \{h_n \}_{i=1}^{N} を求める.

この論文での具体的な設定

エンコーダ部分,サマライズの方法,異なるグラフデータの選択方法は様々な選択肢が考えられる.この論文では,具体的に下記のような設定をおく:


  • エンコーダ: GCN (Graph Convolutional Networks) を利用

    GCNは,各ノード n ごとに隣接ノード \mathcal{N}(n)  と自分自身の特徴量を重みつきで足し上げる以下のような更新ルールでニューラルネットワークを構築する:

    
\begin{align}
h_{n}^{(l+1)} = \sigma \Big{\{} \Big{(}  \frac{1}{c_{n}} h_{n}^{(l)} + \sum_{m \in \mathcal{N}(n)} \frac{1}{c_{nm}} h_{m}^{(l)} \Big{)} \, W^{(l)} \Big{\}}
\end{align}

    ここで, c_{n},  c_{nm} はノードに依存した定数, W^{(l)}ニューラルネットワークで学習させるパラメータ行列である.


  • サマライズの方法: 

    
\begin{align}
s = \sigma \Big{(} \frac{1}{N} \sum_{n=1}^{N} h_n \Big{)},~~~~\sigma: \text{シグモイド関数}
\end{align}

    単純に,全てのノード表現の平均を取る.直感的に,relevant-informationは全てのノードに共通して含まれるようなものだと考えられるので,平均を取ることでrelevant-informationが高められ,逆にirelevant-informationは薄まってくれると期待できる.論文では,いくつか試した中で,平均を取るものが最も良かったと報告されている.


  • 異なるグラフデータ (\tilde{X}, \tilde{A}) の選択:

    今回は,\tilde{A} = A としオリジナルデータと同一に取り, \tilde{X} は行をランダムシャッフルしたものを選ぶ.この選択は任意性があるが,Appendix C で,この選択による影響の差は大きくないということが確認されている.

目的関数の変更

相互情報量は,KL-divergenceとして表され,MINEでは Donsker-Varadhan 表現で書き換えることでニューラルネットで推定問題を解く形に帰着させていた.ここで,KL-divergence を Jensen-Shannon divergence (JS-divergence) に置き換えることを考える. JS-divergence は,KL-divergence を対称にしたものである.


\begin{align}
D_{JS}(P_{XY}||P_X P_Y) = \frac12 \Big{\{} D_{KL}\Big{(} P_{XY}|| \frac{P_{XY}+P_X P_Y}{2} \Big{)} +D_{KL}\Big{(} P_X P_Y \,|| \, \frac{P_{XY} + P_X P_Y}{2} \Big{)} \Big{\}}
\end{align}

例えば,下図は A の人工的に生成した混合ガウス分布 P を,シングルガウス分布 Q で推定したもので,B, C, D の3つの別々な推定方法での比較した結果となる:

f:id:masashi16:20190519215930p:plain (https://arxiv.org/abs/1511.05101)

B は \text{arg}\min_{Q} D_{KL}(P||Q),D は PQ を逆にした \text{arg}\min_{Q} D_{KL}(Q||P) で推定した結果となる.C が,\text{arg}\min_{Q} D_{JS}(P||Q) で推定した結果に対応しており,KL-divergence の推定よりも良い推定結果を与えていることが分かる.

例えば,GANを考えてみると,生成器は理想的に JS-divergence を最小にすることにより学習される. また,GANでは JS-divergence を,より広いクラスである f-divergence や,Wasserstein距離といったものに置き換えることにより,より安定した学習になることが分かっている.この辺りは,例えば,From GAN to WGAN (https://lilianweng.github.io/lil-log/2017/08/20/from-GAN-to-WGAN.html) が詳しい.

ここでは,KL-divergence を JS-divergence に置き換えることを考える:

$$ \begin{align} MI(X, Y) &= D_{KL}(P_{X, Y} \, || \, P_X P_Y) \\ & \to D_{JS}(P_{X, Y} \, || \, P_X P_Y) \geq \sup_{D} (-\mathcal{L}_{BCE}(D)) + \log 4 \end{align} $$

ここで,\mathcal{L}_{BCE} は,Binary Cross Entropy誤差関数,D はその識別器を表す関数である.

上記の最後の不等式部分は,例えば,GANで最適な識別器 D=D^{*} を学習させた際の目的関数が, V(D^{*}) = - \mathcal{L}_{BCE}(D^*) = 2D_{JS} - \log 4 となり,識別器の学習と,JS-divergence の推定の問題はつながっていることからも分かる (詳細は補足参照).

これにより結局,ローカルなノード表現 {h_n } とグローバルな表現 s の間の相互情報量を最大化するという問題は,\{(h_n, s)\}\{(\tilde{h}_m, s)\} の間の識別器を学習する問題に帰着される.

Deep Graph Infomaxまとめ

今まで見てきたことは,論文の3.4 OVERVIEW OF DGIにあるように下記でまとめられる: f:id:masashi16:20190519215903p:plain

この論文では,具体的な設定を置いているが,フレームワークとしては,上記の 別データを作る関数 \mathcal{C}, エンコードする関数 \mathcal{E}, サマライズする関数 \mathcal{R}, 識別器 \mathcal{D} は柔軟に設定してやることが出来る.

f:id:masashi16:20190519215854p:plain

実験

ノードのクラス分類タスク

f:id:masashi16:20190519215958p:plain

ここでは,

  • Transduvtive: 1つのグラフデータだけ与えられていて,その内のいくつかのノードラベルがますくされている半教師ありの設定

  • Inductive: 訓練とテストで,異なるグラフデータが与えられる設定

の2つの設定での実験を考える.

表現学習で得られた特徴量を,ロジスティック回帰で学習を行う.結果としては,冒頭にも載せたような結果となる (分類精度):

f:id:masashi16:20190519215707p:plain [論文中のTable2に記載の実験結果 (ノードのクラス分類タスク)]

ここで特筆すべきが,ラベル情報 Y も使っている教師あり学習のGCNと同等以上の精度を記録していることである.

2次元での可視化

Coraデータセットで,表現学習を行い,t-SNEで2次元に次元圧縮して可視化したのが下図である:

f:id:masashi16:20190519215953p:plain

見て分かるようにうまく分離されており,また従来のシルエットスコアよりも高い値になっているとのことである.

まとめ

最近提案されたDeep Infomaxと呼ばれる新たな表現学習手法をグラフデータに適用することで,従来のランダムウォークに依存しないノード表現学習を可能にした. そして,ノード分類のタスクにおいて,教師あり学習のGCNと同等以上の精度を記録.



補足

KL-divergenceのDonsker-Varadhan表現

$$ \begin{align} D_{KL} (P || Q ) = \sup_{T}~ \mathbb{E}_{p\sim P} [T(p)] -\log \mathbb{E}_{\substack{q\sim Q}} [e^{T(q)}] \end{align} $$

ここでは離散的な場合,T(i) = t_iで示す.右辺が最大となる T を求めると, $$ \begin{align} &\frac{\partial}{\partial t_j} \Big{(} \sum_{i} p_i t_i -\log \sum_i q_i e^{t_i} \Big{)} =0 \\ &\Rightarrow ~~ p_j - \frac{q_j e^{t_j}}{\sum_i q_i e^{t_i}} = 0 \\ &\Rightarrow ~~ t_j = \log \frac{p_j}{q_j} + \underbrace{\log \sum_i q_i e^{t_i}}_{=\alpha\,\text{とおく}} \end{align} $$ これを,実際に右辺に代入すると, $$ \begin{align} \sum_{i} p_i t_i -\log \sum_i q_i e^{t_i} &= \sum_i p_i \Big{(} \log \frac{p_i}{q_i} + \alpha \Big{)} -\log \sum_i q_i \exp \Big{(} \log \frac{p_i}{q_i} + \alpha \Big{)} \\ &= \sum_i p_i \log \frac{p_i}{q_i} = D_{KL}(P || Q) \end{align} $$

JS-divergenceの下限表現

JS-divergenceを含む f-divergence の下限表現は,f-GAN (https://arxiv.org/abs/1606.00709) で求められている.

f-divergenceとは, $$ \begin{align} D_f(P || Q) = \int q(x) \,f\Big{(} \frac{p(x)}{q(x)} \Big{)}\,dx \end{align} $$ ここで,f は凸関数で,f(1)=0 を満たすものである. 例えば,f(u) = u \log u, ~u=p(x)/q(x) とすると,KL-divergenceとなる.

f の凸性から,Fenchel conjugateを考える(例えば,PRMLの10.5が詳しい.凸関数は接線によって下から抑えられるという性質を用いて導かれる双対表現). $$ \begin{align} f^{*}(t) =\sup_{u \in \text{dom}_f} (ut -f(u)) \\ f(u) = \sup_{t \in \text{dom}_{f^{*}}}(tu - f^{*} (t)) \end{align} $$

上記の f-divergenceを,Fenchel conjugateで書き直すと, $$ \begin{align} D_f (P||Q) &= \int q(x) \sup_{t \in \text{dom}_{f^{*}}} \Big{(} t \frac{p(x)}{q(x)} - f^{*} (t) \Big{)} \\ &\geq \sup_T \Big{[} \int p(x) T(x) dx -q(x) f^{*} (T(x))dx \Big{]} ~~~ (\because \text{Jensen不等式+$\alpha$})\\ &= \sup_T ~ \mathbb{E}_{x\sim P}[T(x)] -\mathbb{E}_{x\sim Q} [f^{*} (T(x))] \end{align} $$

これが,一般的な f-divergenceの下限表現となる.特に,今回のJS-divergenceの場合の表現を求めてみると, $$ \begin{align} D_{JS}(P_{XY}||P_X P_Y) \geq \sup_S ~ \mathbb{E}_{P_{XY}}[S(x,y)] - \mathbb{E}_{P_X P_Y}[D_{JS}^{*}(S(x,y))] \end{align} $$ JS-divergence のFenchel conjugate D_{JS}^{*} は, $$ \begin{align} D_{JS}^{*}(S) = - \log (2-e^{S}) \end{align} $$ となるが,ここで関数 S の定義域は,S \lt \log 2 となる.そこで,定義域が \mathbb{R} 全体となるように,以下のように S から T に変換を行う: $$ \begin{align} S(x,y) = \log 2 -\log (1+e^{-T(x,y)}) \end{align} $$ この関数の取り方には任意性はあるが,ここでは f-GAN の論文にあるようにGANの設定とコンシステントになるように選んでいる.

すると, $$ \begin{align} D_{JS}(P_{XY}||P_X P_Y) &\geq \sup_T ~ \mathbb{E}_{P_{XY}} [\log 2 -\log(1+e^{-T(x,y)})] - \mathbb{E}_{P_X P_Y} [-\log (2 -\frac{2}{1+e^{-T(x,y)}})] \\ &= \sup_T ~ \mathbb{E}_{P_{XY}} [-sp(-T(x,y))] - \mathbb{E}_{P_X P_Y} [sp(T(x,y))] +\log 4 \\ &= \sup_T ~ (-\mathcal{L}_{BCE}(T)) + \log4 \end{align} $$ ここで,sp(x) は,softplus関数である.

参考