【論文読み】Weight Uncertainty in Neural Networks
Summary
ニューラルネットワークの重みに確率分布を導入することで、過学習を防ぎ、Dropoutのように複数のモデルをアンサンブルしたような効果が得られる。Contextual Bandit(後述)のように確率的な意思決定が必要な場合にも適用できる。
Proposed Method
ネットワークの重みに確率分布を導入する。ニューラルネットは入力に対しその出力の分布を与えるモデルとして解釈できる。
通常のベイズ推定の枠組みと同様、データに対する尤度と事前確率を用いて、重みの事後確率を
\begin{align} log P(w|D) \simeq log P(w) + log P(D|w) = log P(w) + \sum_{i} log P(y_i | x_i, w) \end{align}
で計算する。*1
MAP(事後確率最大化)の枠組みではこの確率を最大化することを目指すが、これは通常のニューラルネットにおいて重みについての正則化(ガウス事前分布→L2正規化、ラプラス事前分布→L1正則化)を導入した場合とみなすことができる。
推論の際には事後確率について、それぞれのwに対してその取りうる確率で重みづけした期待値である
\begin{align} P(y|x) = E_{P(w|D)} [P(y|x, w)] = \int P(w|D) P(y|x, w)] dw \end{align}
で計算する。これは重みが異なる無数のニューラルネットワークをアンサンブルしたものと解釈できる。
ただもちろんこの積分値は厳密に計算できない(intractable)ので、変分ベイズ法の枠組みで事後確率をパラメータ化された別の単純な確率分布により近似することで対処する。この近似誤差をKLダイバージェンスによって測る。 \begin{align} KL[q(w|\theta) || P(w|D)] &= \int q(w|\theta) log \frac{q(w|\theta)}{P(w|D)} \\ &= \int q(w|\theta) log \frac{q(w|\theta)}{P(D|w)P(w)} \\ &= KL[q(w|\theta) || P(w)] - E_{q(w|\theta)} [log P(D|w)] \\ &= F(D, \theta) \end{align}
これはに対するパラメータのコスト関数とみなせる。これを最小化するを求めることが目的。さらにこのコストはモデルの複雑さに対するペナルティである1つ目の項、データへの適合度を表している2つ目の項に分解して解釈できる。
通常であれば事前分布とに共役な分布を用いてKLを解析的に計算することが多いが、今回はKL項を解析的に計算できない場合についても勾配法を用いた更新を行う手法について考える。
をに関する勾配法で最適化することを考えたい。ここでパラメータに依存しないノイズが存在し、がある決定的な関数を用いてと書ける時(つまり、確率分布が簡単なノイズ項とその変形として分解できる時)、次が成り立つ。
\begin{align} \frac{\partial}{\partial \theta} E _ {q(w|\theta)} [f(w, \theta)] = E _ {q(\epsilon)} [\frac{\partial f(w, \theta)}{\partial w} \frac{\partial w}{\partial \theta} + \frac{\partial f(w, \theta)}{\partial \theta}] \end{align}
これはGaussian reparameterisation trickと呼ばれる手法の拡張であり、確率モデルが単純な場合の微分連鎖律的なものになっている。
よって
\begin{align} \frac{\partial}{\partial \theta} F(D, \theta) &= \frac{\partial}{\partial \theta} \int q(w|\theta) \log \frac{q(w|\theta)}{P(D|w)P(w)} \\ &= \frac{\partial}{\partial \theta} E _ {q(w|\theta)} [\log q(w|\theta) - \log P(D|w) - \log P(w)] \\ &= E _ {q(\epsilon)} [\frac{\partial f(w, \theta)}{\partial w} \frac{\partial w}{\partial \theta} + \frac{\partial f(w, \theta)}{\partial \theta}] \end{align}
ここでとおいた。この形まで持ってくれば、をに関してサンプリングして中身の微分を計算することで期待値を近似的に計算することができる。この期待値を勾配としてを更新していくことでモデルの最適化を行うことができる。
は無数の辺の重み*2であり、そのそれぞれが互いに独立なガウス分布に従うと仮定する。するとを対角成分としてもつ対角行列を用いてとかける(つまり)。
の条件を入れるためとおき直すと、を用いてとかける。パラメータであり、は要素ごとの積を表す。*3
この条件のもとで \begin{align} \frac{\partial w}{\partial \mu} &= 1 \\ \frac{\partial w}{\partial \rho} &= \frac{\epsilon}{1 + exp(-\rho)} \end{align}
を使うことで、ベイズ化ニューラルネットの全体のアルゴリズムを次のように書ける。
- に基づいてをサンプリングする。
- を計算。
- を計算。
- に対する勾配を次の式で計算する: \begin{align} \Delta_{\mu} = \frac{\partial f(w, \theta)}{\partial w} + \frac{\partial f(w, \theta)}{\partial \mu} \end{align}
- に対する勾配を次の式で計算する。 \begin{align} \Delta_{\rho} = \frac{\partial f(w, \theta)}{\partial w} \frac{\epsilon}{1 + exp(-\rho)} + \frac{\partial f(w, \theta)}{\partial \rho} \end{align}
- 勾配を用いてパラメータ, を更新する。
は通常のバックプロパゲーションで得られる勾配であり教師あり学習の枠組みで計算できる。
この手法ではとを共役にして]を解析的に計算する必要がないため、を自由に選ぶことができる。例えばこの論文では混合ガウス分布を使っている。
訓練データをミニバッチに分割する場合は全体に対して和をとってに一致させる必要があるので、KL項に適切に重みづけする。例えば
\begin{align} F_i^{\pi}(D_i, \theta) = \pi_i KL[q(w|\theta) || P(w)] - E_{q(w|\theta)} [\log P(D_i|w)] \\ \end{align} \begin{align} \pi _ i = \frac{2 ^ {M - i}}{2 ^ M - 1} \end{align}
とすると、学習初期は複雑コストの影響を大きく受け、後半はデータ適合性が重視されるように学習がうまく進むことがわかった。
Experiments
MNISTの分類問題と単純な曲線に対する回帰、Contextual Banditの3つについて実験を行なっている。またネットワークの圧縮についても議論している。
MNIST
性能としてはDropoutと同等の性能。重みの分布を見るとより多くのエッジが協調して予測を行なっていることがわかる(取り除かれると結果に大きく響くクリティカルなエッジが少ない)。
またSingle-to-Noise()に注目して辺の重みの分布を調べており、その値が小さいものから順に辺を除去していくことで、90%付近までほぼ精度を落とさずにネットワークを圧縮できることを示している。
Regression
従来のモデルがデータのない領域では一つの曲線にフィッティングしてしまっているのに対し、提案手法ではデータがないことによる不確実さを評価しちゃんと分散が大きくなっている。
Contextual Bandit
各ステップで食用or毒のキノコの特徴が与えられ、食用のキノコを食べる選択をした場合には+5、毒キノコを食べる選択をした場合には50%で+5か50%で-35の報酬が与えられる(食べない場合は0)という強化学習(?)の問題設定で検証を行なったもの。
予測に確率モデルを導入することで、問題の不確実さを評価しより良い性能が得られていることがわかる。
感想
比較的シンプルな形でニューラルネットに確率を導入できている点が興味深いと思いました。論文でもフレームワークの力を借りれば今すぐにでも実装が可能という点が強調されています。
無限のモデルをアンサンブルできるとは言いつつ結局サンプリングで1-5個しかとってこれないのは正直どうなん?とも思いましたが、Regressionで示されているようにデータ数の少ない場合や強化学習において確率的な環境を扱う場合に有用なのではないかと感じました。
やっぱりベイズ推論系の話はちゃんと数式が出てきて安心しますね笑。ガウス過程+ニューラルネットワークとかもそろそろ出てきているんでしょうか?また調べてみたいと思います。
最後までお読み頂きありがとうございました。