JSAI2019レポート(1)

世間はすでに7月ですが、6月の頭に新潟で開催された日本人工知能学会の個人的に気になった発表のまとめをなんとか書き終えましたので供養しておきます。

簡単な解説とコメント、関連論文などを載せています。個人的に気になった発表だけフォーカスしているので、分野に偏りがあるのはあらかじめご了承ください。強化学習・不変学習系が中心です。

長くなってしまったので2回に分けてお送りしたいと思います。

強化学習

強化学習と模倣学習の融合による人間らしいエージェント

https://confit.atlas.jp/guide/event-img/jsai2019/1Q2-J-2-01/public/pdf?type=in

通常の強化学習によって得られた方策とエキスパートの方策を \alpha : 1 - \alphaで混ぜたものを教師として模倣学習させる。 模倣学習にあたってはGAILというGANの考え方を取り入れた手法があり、与えられた方策がエキスパートかを判別するDiscriminatorと訓練エージェントを敵対的に学習させていく。

方法は単純だが強化学習に「人間らしさ」を持たせる発想は面白いし、これから大事になっていきそう(対人コミュニケーションとか)。

関連: Generative Adversarial Imitation Learning

階層型強化学習における人間のサブゴール知識転移

https://confit.atlas.jp/guide/event-img/jsai2019/1Q2-J-2-02/public/pdf?type=in

複数の候補方策をもち、それら状況に応じて切り替えること(メタ方策)を学習する階層型学習のフレームワークとしてOption-Criticというものがあるらしい。その中でエージェントが自動的に獲得するサブゴール知識を人が与えられるようにしたらどうなるかという提案(実験未完)。

f:id:ey_nosukeru:20190703235630p:plain

関連: Option-Critic

階層型強化学習RGoalアーキテクチャへの再帰呼び出し用スタックの導入

https://confit.atlas.jp/guide/event-img/jsai2019/3D4-OS-4b-01/public/pdf?type=in

RGoalとは、大域的な目的の他にサブゴール gを導入し、現在向かっているサブゴール gを方策 \pi(s, g, a)や価値関数 Q(s, g, a)というように明示的に導入して学習させるようにしたもの。今回の研究では特にサブゴール gに向かっている際の別のサブゴールの設定 g'をスタックの導入によって可能にし、 ((S \rightarrow g') \rightarrow g) \rightarrow Gのような遷移ができるようにした。

実験をみる限り連続空間・高次元な状態へのスケーリングは現時点ではできていなさそうだが、人間の学習能力においてサブゴールという概念は重要な役割を果たしているように思うので今後の進展に注目したいです。

関連: RGoal

深層強化学習エージェントの行動別顕著性マップの生成に関する考察

https://confit.atlas.jp/guide/event-img/jsai2019/3K4-J-2-01/public/pdf?type=in

顕著性マップとは簡単にいうとフレームの一部をマスクしてぼかした時にどれくらいエージェントのパフォーマンスに影響するかを計算して可視化したもの。本研究ではその変動の評価関数の絶対値を外してパフォーマンスにプラスに働くものとマイナスに働くものを区別して可視化できるようにした。

f:id:ey_nosukeru:20190703235730p:plain:w300

関連: Grad-CAM, Visualizing and Understanding Atari Agents

複数の報酬関数を推定可能なタスク条件付き敵対的模倣学習

https://confit.atlas.jp/guide/event-img/jsai2019/4I3-J-2-02/public/pdf?type=in GAILにおいて生成器(エージェント)や識別器にタスクのコンテキストを表す潜在変数 cを導入することで、複数のタスクを扱えるようにしたInfoGAILやconditionalGAILという枠組みがある(モーションキャプチャにおける人の動きの模倣などの応用があるそう)。エントロピー正則化項を導入して学習を安定化させ、迷路上での複数タスク(目的地)で実験を行った。

サブゴール系とも関連しますがコンテキストの導入・マルチタスクという部分が目新しく興味を惹かれました。言及されているモーションキャプチャでの応用やメタラーニングと組み合わせて未知のタスクに対応させるなどできたら面白そうです。

関連: InfoGAIL, conditionalGAIL

大局基準値共有による社会的強化学習

https://confit.atlas.jp/guide/event-img/jsai2019/3K3-J-2-04/public/pdf?type=in

複数のエージェントを走らせつつエージェント間でパフォーマンスを共有し、それを元に基準値を決定し、報酬調整によりそれを上回っているエージェントについては活用(exploitation)を、下回っているエージェントについては探索(exploration)を積極的に行わせるようにしたRisk-Sensitive Satisficingという学習法がある。各状態・行動ペアに対して基準値を設定していた従来手法に対し、大局的(最終的)な結果を元にそれぞれの状態の基準値を適切に設定することで学習を安定化した。エージェントごとにexploitation/explorationをうまく使い分けて、全体としては安定的に学習が進んで行くことが期待される。

人間の競争志向を取り入れた面白い学習法だと思いました。ただ現状状態カウントベースな量を報酬に組み込んで実験も離散的な設定でしかできていないので、うまくDNNと組み合わせてスケール化が進むことを期待したいです。

関連: 満足化強化学習

深層強化学習を用いたWebサイト内行動のレコメンド

https://confit.atlas.jp/guide/event-img/jsai2019/4O2-J-2-01/public/pdf?type=in

エージェントをWebサイト、状態をサイト訪問中のユーザーの特徴量、アクションを推薦するアイテム等として、報酬であるコンバージョン(会員登録や商品購入など運営がユーザーに期待する行動)を最大化すルためのアルゴリズムとして強化学習を適用したという発表。強化学習を推薦に用いるというのをあまり聞いたことがなかったので新鮮でした。

関連: DRN: A Deep Reinforcement Learning Framework for News Recommendation

進化的計算と方策勾配法による学習を用いた3次元制御タスクにおけるマルチタスク深層強化学習

https://confit.atlas.jp/guide/event-img/jsai2019/4Rin1-04/public/pdf?type=in

3次元におけるマルチタスク学習において、通常の勾配を用いる強化学習に遺伝アルゴリズムを組み合わせた手法。基本的には全く勾配計算を行わずランダムに初期化したニューラルネットを交叉したり突然変異させたりして世代を更新していくが、エリート(性能のいいサンプル)だけは実際に勾配計算を行い、最適化したものを次の世代に追加することで学習の促進を促す。

f:id:ey_nosukeru:20190703235146p:plain

提案手法だと複数タスクの学習が混ざってうまく行かないような気もするのですが、マルチタスク学習において有利なパラメータを遺伝的アルゴリズムで最適化するというのはかなり面白いアイデアだと思いました。

不変学習・メタラーニング

ペアワイズニューラルネット距離による不変表現学習

https://confit.atlas.jp/guide/event-img/jsai2019/1I4-J-2-02/public/pdf?type=in

何らかの属性 aに分類される観測値 xについて、目標値 yに関する情報が保存され、かつ aに依存しないような特徴量 zを抽出することを目標とする分野として不変表現学習がある。この a zの依存性の尺度としては条件つきエントロピーが理想であるが、解析的に計算することはできず、この部分を zから aの予測のしづらさとしてエンコーダー E(x)と予測器 M(z)を敵対的に学習させる敵対的特徴学習が提案されているが、実際には不安定な挙動を示すことがある。別の手法としてそれぞれのカテゴリに対応する zの分布 P_a (z)の全ての組み合わせについてその分布間距離 \frac{1}{|A|} \sum _ {a, a'} d( P _ a(z), P _ {a'}(z) )を指標として用いるペアワイズ不変学習を提案している。分布間距離としては2つの分布からのサンプルを区別するよう訓練された識別器の性能によってその近さを評価するニューラルネット距離を用いる。ペアワイズ距離は通常 |A|の2乗に比例する識別器が必要になるが、識別器を特徴抽出部 Gと線形変換部 Hの合成 G \circ Hとして表せると仮定することで、属性数 |A|の出力を持つ普通のニューラルネット1つによって共通化できる。

f:id:ey_nosukeru:20190703235852p:plain

関連: Adversarial Feature Learning

分類性能による制約を考慮した敵対的不変表現学習によるドメイン汎化

https://confit.atlas.jp/guide/event-img/jsai2019/1Q4-J-2-03/public/pdf?type=in

ドメイン汎化(訓練データにないドメインのデータに対しても妥当な推論が可能になることを目指すタスク)では観測値 xの特徴量 zドメイン dに関する情報量を持たないことを目指す不変学習が用いられることが多いが、実際にはドメインと観測値はある程度の相関をもつ場合が多く、その関係が失われることが推論の精度を下げてしまっている可能性がある。特徴抽出に際して H(d|z) = H(d) z dに関する情報を持たないこと)ではなく H(d|z) = H(d|y)(目標値が持っている dについての情報と同じだけ、 z dについての情報を持つこと)を目指し、 p(d|z) p(d|y)のKLダイバージェンスを誤差関数に組み込んで学習させる。これにより、推論精度を損なわない範囲で不変的な特徴量が学習されることが期待される。

f:id:ey_nosukeru:20190703235932p:plain

メタ学習としてのGenerative Query Network

https://confit.atlas.jp/guide/event-img/jsai2019/2Q5-J-2-03/public/pdf?type=in

3次元空間上の複数視点からの画像を入力して訓練することで、任意の視点からの画像を生成することのできるGQN(Generative Query Network)が昨年DeepMindから発表され話題になった。しかしこのモデルをマルチタスク学習のためのメタ学習の枠組みから見るとアーキテクチャ的に余分な変数や依存関係が存在する。この考察に基づいて変数を削減することで、学習の安定性だけでなく性能自体やハイパーパラメータに対する頑健性も向上した。

f:id:ey_nosukeru:20190704000006p:plain:w350

個人的に洗練されたアーキテクチャ改善手法に感動し、メタラーニングに興味を持ったきっかけになった発表です。

関連: GQN, メタ学習(ML-PIP)

次回につづく

モナリザをしゃべらせる技術、Few-Shot Adversarial Learning of Realistic Neural Talking Head Modelsの論文紹介

https://i.gyazo.com/4dc21b010b38144320d51916d1df59e7.gif

arxiv.org

今回はこの論文を紹介します。いよいよハリーポッターの魔法の世界も夢ではなくなってきました感じがしますね。

Background

技術としては画像生成の分野ではお約束のGAN(Generative Adversarial Network)を中心として、style変換に用いられるAda-IN(Adaptive Instance Normalization)やEncoder(Embedder)の出力をGeneratorの各層に入れるアーキテクチャProjection Discriminatorマルチタスク学習・Few-shot learning*1のためのMeta-learningなどが用いられています。

以下に関連リンクを貼っておきます。

Adaptive Instance Normalization

[1703.06868] Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization (cited: 238)

Projection Discriminator

[1802.05637] cGANs with Projection Discriminator (cited: 51, 日本の方の論文!)

MAML

[1703.03400] Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks (cited: 644)

Proposed Method

f:id:ey_nosukeru:20190629224006p:plain
全体のアーキテクチャ

全体のアーキテクチャは上のようになります。

Embedder

 e = E(x, y; \phi)

Embedderは人の画像 xとその顔の向きを表すランドマーク yを入力して受け取り、入力画像についての人物や背景そのものに関する特徴量 eを出力します。*2 この特徴量は顔の向きに依存せず、人物の種類のみの情報を抽出することを目指します。 \phiはネットワークのパラメータ。

Generator

 x = G(y, e; \psi, P)

ランドマークを入力として受け取り、Embedderの出力 e Pで線形変換したスタイル \hat{\psi} = PeをAdaIN(Adaptive Instance Normalization)を用いて各層に注入しながら特徴抽出・生成画像の肉付けを行い、最終的に与えられたスタイルとランドマークに照らして自然な画像を出力します。 \psiがネットワークのパラメータ、 Pが線形変換のパラメータになります。

ここでAdaINは入力を x、スタイルを sとして次の式で計算されます。

 AdaIN (x, s) = \sigma (s) \frac {x - \mu (x)}{\sigma (x)} + \mu (x)

ここで x sがバッチ単位、多チャンネルの画像、 \mu(x) \sigma(x) xのサンプルごと・チャンネルごとの平均と標準偏差(つまり各サンプル・チャンネルの画素値についての統計量)、 \mu(s) \sigma(s)も同様の統計量です。

DNNにおいては特徴ベクトルの平均・標準偏差などの統計量が画像のスタイルを決定する上で重要な役割をもつことが知られており、AdaINによってこれを操作することによって画像のスタイルを変換することが可能になるようです。

以上のようにAdaINの論文[1703.06868] Arbitrary Style Transfer in Real-time with Adaptive Instance Normalizationには書かれているのですが、ここでの e \hat{\psi}は単純なベクトルなので微妙に細部が異なっていそうです。恐らくStyle-Based GANと同様に eからそれぞれのレイヤー、チャンネルごとの \mu(s) \sigma(s)への変換を直接学習するようにしているのではないかと思います。つまり \hat{\psi}がそれぞれのレイヤー、チャンネルごとの \mu \sigmaを集めたものになります。

Discriminator

 r = D(x, y, i; \theta, W, w_0, b)

通常のGAN同様 xが本物か偽物かを判定します。 iはデータセットの種類、つまり人物のインデックスを表しています。 Wは人物ごとの特徴量(つまりeに似た役割を持つもの)を集めてきたもので、 W_iが対象の人物についての特徴量となります( Wもパラメータ)。畳み込み層の出力ベクトルを、 W_i + w_0, bのパラメータを持つ全結合層に入れて rを計算します。 W_iの役割については後で詳しく説明します。

Training

画像の生成方法

Embedderの出力値 eを顔の向きに依存せず、人物のみに依存するパラメータにするために次のような手順を踏みます。

  1. 人物のインデックスiをランダムに選び、そのデータセットからランダムにK個のサンプルフレーム x _ {i} ^ k, y _ {i} ^ kをサンプリングする。
  2.  \hat{e_i} = \frac{1}{K} \sum_{k=1}^K E(x _ i ^ k, y _ i ^ k; \phi)により人物の特徴量を計算。
  3. これを元に推論画像 \hat{x_i} = G(y_i, P\hat{e_i})を計算。

これによって \hat{e_i}の各フレーム(つまり顔の向き)に対する依存性が失われ、不変な特徴量が学習されることが期待されます。

誤差関数

この手順で得られた出力画像を元に次の誤差関数を最小化することを目指します。

 L = L_{content} (\hat{x_i}, x_i) + L_{adv} (\hat{x_i}, x_i, y_i, i) + L_{match} (\hat{e_i}, W_i)

 L_{content}はGeneratorの出力画像 \hat{x_i}と正解の画像 x_iの中身に関する誤差で単純なピクセル二乗誤差からVGGに突っ込んだ時の中間層の特徴量の二乗誤差などが用いられます。

 L_{adv}は出力画像 \hat{x_i}と正解の画像 x_iがDiscriminatorにとってどれだけ見分けづらいか、つまり出力画像のリアルさに関する誤差で、GANの基本となる誤差です。普通は \log D(\hat{x_i}, y_i) + \log (1 - D(x_i, y_i))などが用いられることが多いですが、ここでは次のヒンジ型誤差を使っています(最近の流行り?)。

 L_{adv} (\hat{x_i}, x_i, y_i, i) = max(0, 1 + D(\hat{x_i}, y_i, i)) + max(0, 1 - D(x_i, y_i, i))

 L_{match}はEmbedderによる人物特徴量 e_iとDiscriminatorの人物特徴量 W_iが一致することを要請する特徴量で、単純に L_1誤差が用いられるようです。

 W_iの存在意義

 e_i W_iを近づけるよう学習するなら最初から W_iの部分に e_iを突っ込めばいいやんけとなりそうですが、この工夫には次のような事情が考えられそうです。

  •  e_iはモデルの出力値で訓練中大きく変化する上、Generatorは e_iに大部分依存しているのでこれをDiscriminatorにも使うと学習が不安定になりそう
  • そもそもGeneratorとDiscriminatorでパラメータを共有するのはまずいのでは?(敵対的に学習させるので両者が完全に独立であるべき )

f:id:ey_nosukeru:20190630005823p:plain
メタ学習のイメージ(ML-PIPより)

またそもそもなぜ W_iのようなパラメータを用意するかについてですが、メタ学習は上図のようにタスク(ここではそれぞれの人物)に依存しないパラメータとタスクごとのパラメータを明示的に分け、共通パラメータをうまく学習させることでタスク依存のパラメータの最適化を効率化する、ということが目的のフレームワークです。ここではモデルパラメータ \phi, \psi, \thetaなどが共通パラメータにあたり、人物特徴量 e_i W_iはタスク依存のパラメータになります。 W_iに人物への依存性を集中させることでDiscriminatorのパラメータ \thetaの人物への依存性を減らしていると考えられます。

Fine-Tuning

このプロセスによりモデルの学習が完了したら、次のような手順で推論ができることになります。

  1. 推論対象の人物の画像を複数Embedderに入れて e _ {new} = \frac{1}{T} \sum _ {t = 1} ^ T E(x _ i ^ t) \hat{\psi} = Pe _ {new}を計算
  2. 推論対象のランドマーク y \hat{\psi}をGeneratorに入れて推論結果 \hat{x} = G(y, \hat{\psi})を得る

このまま推論をしてもある程度はうまく行くようですがさらにリアルな画像を得るために次のFine-Tuningを行うと良いみたいです。

  1. 複数の人物に対して推論する必要はもうないのでEmbedderは不要、代わりに直接スタイル \hat{\psi}を最適化する。初期値を \hat{\psi} = Pe_{new}としてGeneratorの新しいパラメータの一部とする。
  2. Discriminatorの W_iも人物ごとに保持しておく必要はなく、 w' = W_i + w_0の部分を直接最適化すれば良い。ただし W_iはわからないので代わりに e _ {new}を使い、初期値を w' = e _ {new} + w_0として新しいパラメータ w'を作る。
  3.  L _ {match}以外の誤差 L = L _ {content} + L _ {adv}により(通常のGANと同様に)パラメータ \psi, \hat{\psi}, \theta, w', bを最適化する。

以上により推論対象の人物に最適化されたパラメータを得ることができます。

Results

結果については公式のビデオを見るのが圧倒的にわかりやすいので割愛します。 L _ {match}やFine-Tuningを使わない場合の比較も行なっています。

www.youtube.com

Comment

GANの部分についてはさほど目新しさはありませんが、メタ学習の考え方を取り入れることで人物依存のパラメータと共通のパラメータをうまく分離して効率的な学習が可能になっているのが興味深い点だと思います。特にFine-Tuningの部分の、前段階の学習の結果の出力値を初期値として新しいパラメータを作り、それを目的のタスクに向けて最適化していくというのがとても面白く、参考になりました。

個人的に最近メタ学習やfew-shot learningに興味が出てきたのでこれからそちらの方向に勉強の範囲を広げていければと思っています。

最後までお読み頂きありがとうございました!

*1:ハイパーパラメータやネットワークの初期値を調整することで、訓練データにないクラスに対してもごく少量のデータを学習することで高い精度での推論を目指す枠組み

*2:Embedderにもランドマークを入力しますが、これは単に顔の向きに関する情報を与えることでそれ以外の特徴抽出を容易にするためだと考えられます。

Transformer完全に理解した

f:id:ey_nosukeru:20190622041311p:plain:w300
BLEACH』170話より

出オチです。nosukeruです。

今回はこの論文に関する解説記事になります。

arxiv.org

読んだのは結構前なのですがわからない部分が多くて放置しており、紆余曲折の末理解できた感動や躓いた部分を共有したく思い記事にまとめました。

Transformerとは

f:id:ey_nosukeru:20190622044350p:plain:w500
イメージ図

2017年12月頃の上の論文でまたもやGoogleチームにより発表され世間の注目を集めた、時系列データ処理に革新をもたらしたモデルです。

  • それまでの時系列データ処理における主役であったRNNやCNNを全く使わないアーキテクチャ
  • Attention機構、特にSelf-AttentionやMulti-Head-Attentionのみを使用
  • 並列処理を可能にしたことにより高速な学習を実現、性能も当時の最高記録を更新

などの目覚ましい特徴により多くの注目を集め、時系列データ・自然言語処理の歴史を塗り替えてしまいました。(引用数2000超え!) これ以降はこれらの分野においてはほとんどのモデルでAttentionが標準的に用いられるようになっていきます。

f:id:ey_nosukeru:20190622044119p:plain:w500
モデルアーキテクチャ

Transformerについては次の@Ryobotさんの記事に大変わかりやすくまとまっており、大いに参考にさせて頂きました。詳細についてはこちらを参照してもらえればよいかと思います。

deeplearning.hatenablog.com

この記事では、Transformerのパーツの中でも特に重要な役割を果たすAttention機構についてと、上の記事ではやや理解しづらかった、学習の仕組みや全体の処理の流れに焦点を当てて説明していきます。

Attentionとは

Attention自体はこの論文以前にあった技術で、人間が画像や文字列を見る際にその一部に注目しているという「注意」の機構を参考にし、ニューラルネットに対してもデータの一部により焦点を当てて特徴を抽出する仕組みを取り入れることで性能の向上を図ったもの、ということができます。

この論文を読む以前はそのぐらいの漠然としたイメージだったのですが、この論文では注意を「メモリデータに対するクエリによる検索」と説明していてしっくりきました。

具体的にはクエリ群 Q \in \mathbb{R}^{D \times N}、辞書のキー K \in \mathbb{R}^{D \times M}、値 V \in \mathbb{R}^{S \times M}に対し、注意によって得られる特徴ベクトルを softmax(QK^{T})Vによって計算します。これは q \in Q k \in Kのそれぞれの内積 q \cdot kを計算し、その値で重み付けして対応する v \in Vを足し合わせることに相当します。内積が大きい \iff特徴ベクトルが似ているということなので、 qに類似したレコードを検索して取ってくる作業(をソフトにしたもの)とみなすことができます。

f:id:ey_nosukeru:20190622045649p:plain
上記事からの引用。辞書 (K, V)に対し複数のクエリ Qで検索を行うイメージ

Attention自体のイメージが掴めれば他はそこまで難しくありません。Self-AttentionはAttentionの Q, K, Vに全て同じデータを入力したもの、Multi-Head Attentionは線形変換をかませてから入力し、この線形変換の組を複数用意して別々にAttentionを計算・組み合わせる機構のことです。

f:id:ey_nosukeru:20190622125223p:plain
Multi-Head Attentionの構造

図中のScaled-Dot Attentionは通常のAttentionにスケールの調整を加えたものです( \text{softmax}(\frac{QK^{T}}{\sqrt{d}})V)。理由としては dが大きくなった場合にsoftmaxの値が一部に集中せず全体的にマイルドな値にするためとのことです。

僕が論文を読んでいて気になった(分からなかった)ポイントとそれに対する答えをいくつか紹介します。

Self-Attentionは同じ位置に対する注意(softmax部分の対角成分)が一番大きくなる?

普通のAttention(かつ特徴ベクトルの大きさが大体等しい場合)では言えそうですが今回の場合は入力の前に Q, K, Vそれぞれに別の線形変換を施しているので、同じ位置の内積が必ず大きくなるとは限りません。線形変換によって特徴量のどこに注目して処理するかを決めてるイメージかなと思います。

Attentionの入力、出力はそれぞれ何?

上の説明ではAttentionをメモリの検索という風に抽象化したので実際に何が入力され、何が出力されるのかイメージするのが難しいかもしれません。

時系列データ、特に翻訳の文脈では K, Vのどちらも参照先の時系列データ X = \{x_1, x_2, ..., x_M \}を入れることが多く、 Qに対象のデータ Y = \{y_1, y_2, ..., y_N \}が入力された場合には(それぞれの特徴ベクトルは d次元ベクトルで共通)、出力として y_iのそれぞれに対するAttentionの計算値 A = \{a_1, a_2, ..., a_N \} a_iはd次元)が得られます。

例えば翻訳では X = \{"これ", "は", "ペン", "です", "。"\}が翻訳前の文章、 Y = \{"This", "is", "a", "pen", "."\}翻訳語の文章で、Xのそれぞれの要素に対して単語の類似度("is"ならば"は"、"pen"ならば"これ"や"ペン")に基づいてその特徴量を足し合わせたものがAttentionである Aということができます。

f:id:ey_nosukeru:20190622133604p:plain
(Self-)Attentionのイメージ図。関連する単語同士の重みが大きくなっている

上の図のように特にSelf-Attentionでは文章内の関連する単語を結びつけるような働きが期待できます。

で、Attentionって結局何者?

僕の場合特徴量同士の類似度を測る、ってところまでは理解できたのですがその後の特徴量を足し合わせる、っていうのが直感的にいまいちピンと来ずモヤモヤしてました。

f:id:ey_nosukeru:20190622134801p:plain
畳み込みとAttentionの比較(http://deeplearning.hatenablog.com/entry/transformerから引用)

これについても@Ryobotさんの記事がとても参考になり、Attentionを畳み込みと比較するのがわかりやすいです。つまり、近傍の要素に重みをつけて特徴量を足し合わせている畳み込みに対し、Attentionでは重みパラメータの代わりに要素同士の類似度によって重みをつけて全ての要素を足し合わせています。つまり、Attentionは「パラメータなしで類似度による重みづけを行う全範囲版の畳み込み」とみなすことができそうです。範囲が限定されないために離れた位置の依存関係を扱うことができること、「メモリの参照」による「キーと値の間の非自明な変換」の能力が、Attentionの表現力の高さの一因となっていると言えると思います。

学習・推論の流れ

学習・推論時にデータをどのように入力し、どのような出力が得られてどのように学習していくのかというのが今回の論文の中で一番理解が難しいところではないかと個人的には思っています。アーキテクチャの図を再掲し、それぞれのレイヤーでのデータがどうなっているのかとともに説明していきます。以下では、日本語を英語に翻訳するタスクを例に考えます。

f:id:ey_nosukeru:20190622044119p:plain
モデルアーキテクチャ(再掲)

Encoder部

まず、左のEncoder部により日本語の文章を何らかの特徴ベクトルに変換していきます。文章をデータとして扱わせるため、まず各単語を何らかの特徴ベクトルに変換します(埋め込み)*1

また例によって効率的な計算のためにバッチ処理を行うので、バッチサイズを b、単語数を n、埋め込みの次元を dとすると入力ベクトルの次元は (b, n, d)になります。これに位置の情報が入るような埋め込み(Positional Embedding)を足し合わせてEncoderへの入力とします。

f:id:ey_nosukeru:20190622160946p:plain
Encoder部

Encoder部ではまず入力に対してSelf-Attentionを計算します。バッチのデータ1つごと、つまり (n, d)のデータに対してSelf-Attentionを計算し、その出力の次元も (n, d)となります。よってSelf-Attention後も入力サイズは変わらず (b, n, d)となります。Attention結果と元のデータをresidual的に足し合わせ、適切に正規化して中間層のデータ (b, n, d)が得られます。

次にFeedForward層によってそれぞれの特徴ベクトルを変換します。この処理は時系列に関する計算は行わず、単純にそれぞれの特徴ベクトルを全結合層により変換します。隠れ層の次元を d_h、出力層の次元を d_oとするとデータの次元は (b, n, d) \rightarrow (b, n, d_h) \rightarrow (b, n, d_o)となります。ただしこの論文では単純に d_o = dとしているため、FeedForward層によって次元は変化せず結局 (b, n, d)となります。

この処理がEncoderのブロック分繰り返されるのですが、最終的な次元は変わらず (b, n, d)になります。これでEncoderの出力が得られました。

Decoder部

ここが最大のポイントなのですが、RNNでは現在の位置の単語を入れて次の単語についての予測結果を得るのに対し、Transformerでは1つの文章(つまり m個の単語)を受け取ってそれに対する m個の予測結果を同時に計算します

 Y = \{y_1 = "This", y_2 = "is", y_3 = "a", y_4 = "pen", y_5 = "."\}という入力を例にとって考えると、Decoderは Yを受け取った結果として \{P(y_1), P(y_2|y_1 = "This"), P(y_3|y_1 = "This", y_2 = "is"), ... \}を出力するということです。つまり i番目の出力は 1から (i - 1)番目の単語を受け取った時の i番目の単語に対する予測結果ということになります。

ただし普通にやると入力でそれぞれの予測の答え y_iが与えられてしまっていてまともな問題にならないので、 y_iを予測する際に y_1から y_{i - 1}までの結果のみを使うようにモデルに制約をかける必要があります。そのためにTransformerではAttentionの部分で未来を参照するような部分の重み(クエリ y_1に対するキー y_3の類似度)を強制的に0にすることでそのような情報の漏洩を防いでいます。

f:id:ey_nosukeru:20190622161015p:plain
Decoder部

さて、Encoderと同様データの流れを見ていくと、まずは英語文のバッチ化された単語の埋め込み (b, m, d)がDecoderに対する入力として与えられます。次に(マスクされた)Self-Attentionを通りますが、データの次元は変わらず (b, m, d)になります。この処理により予測対象の単語より前の単語列の特徴量が計算されます。

次に翻訳前の文章の特徴量を活用するため、Encoderの出力を受け取ります。EncoderからのデータをKey/Valueとして、現在の特徴量をQueryとしてAttentionをかけます。この部分のAttentionの働きは、翻訳後の文章のそれぞれの位置に対して、「その位置の前の単語の特徴量を活用しつつ、翻訳前の文章中で似た特徴量を持つ単語の位置を探してくる」ようなイメージなのかなと思っています。

Attentionの出力データの次元は変わらず (b, m, d)になります。AttentionにおいてはKeyとQueryの長さは一致する必要がないことに注意です(内積を計算するため要素の次元 dは一致する必要あり)。

FeedForward部、ブロックの繰り返しについてはEncoder部と同様です。Decoder部ではこの後単語次元 lへの線形変換がかけられ、softmaxによって各位置の単語が l個の候補単語のどれであるかの確率 P(y_i|y_1, y_2, ... y_{i - 1}) (B, m, l)次元の出力ベクトルとして得られることになります。

学習・推論の流れ

学習の際には、Encoder部への入力として日本語の文章 D_X = \{X_1, X_2, ..., X_b\}を、Decoder部への入力として D_Xに対応する答えである英語の文章 D_Y = \{Y_1, Y_2, ..., Y_b\}を与え、出力としてそれぞれの単語の生起確率 P_Y = \{p_1, p_2, ..., p_b\}が得られます。ここで p _ {ji} = P_j(y_i|y_1, ..., y _ {i - 1})がj個目のデータのi番目の位置での単語の予測結果になっています。この出力に対して正解とのCross Entropyを誤差として学習を行います。

翻訳後の文章を予測させたいのにその正解の文章をモデルに入力として与えているのがかなり奇妙に見えますが、このようにデータを一気に与えて並列計算を可能にし、GPUの性能をフルに使うことができるようにしたのがTransformerの大きな成果の一つです。

推論時には正解の文章が分からないため、Decoder部に D_Yを直接与えることができません。そのため、まずEncoder部に D_Xを入力してその出力を得て、その値を用いて P(y_i|y_1, ..., y_{i - 1})を計算していきます。まず P(y_1)を計算し*2、その確率を最大化するような単語として y_1を決定します。そして得られた y_1を結合してもう一度Decoder部に入力し、 P(y_2|y_1)から y_2を決定します。これを繰り返して文章の長さを1つずつ伸ばしていき、出力として文章の終わりを表す特殊記号が現れた時点で文章の生成を終了します(入力のサイズ: (B, 1, l) \rightarrow (B, 2, l) \rightarrow ...)。

なお、文章を伸ばしていきながら末端の単語を計算する時、副産物として P(y_1) P(y_2|y_1)は毎回得られますがこれらを計算に使うことはありません*3。この無駄な結果を毎回計算するという部分も直感に反し、Transformerの挙動を理解する上で難しかったポイントでした。

最後に

今回はTransformerを理解する上で自分が躓いたポイントを元に解説させて頂きました。少しでも理解の一助となれば幸いです。

間違っている部分などあればご指摘頂ければありがたいです。最後までお読み頂きありがとうございました!

*1:なおここの埋め込みもパラメータ化して学習するようです。

*2:この際の入力には文章の始まりを表す特殊単語を与えます。

*3:ちなみにこれらの依存する変数の内容は変わらないので、これらは毎回同じ値になります。

【論文読み】How Powerful Are Graph Neural Networks ?

arxiv.org

Summary

乱立しているGNNを一つの枠組みで解析・整理した上で、その理論上最も強力なモデルやその条件を提唱し、実際に良い性能が得られることを確認した論文。

Contributionは次の4つ。

  1. グラフの構造を識別する能力において、(後で定義する)GNNのクラスの能力が高々WL-test(後述)以下であることを示した。
  2. グラフの構造を識別する能力がWL-testと同等となるためのGNNに対する条件を示した。
  3. 有名なGNN(GCN, GraphSAGE)が区別できないグラフ構造の例を示した。
  4. 提案するGIN(Graph Isomorphism Network)がWL-testと同等の能力をもつことを示した。

Preliminaries

GNNの定義

GNNのクラスを、その特徴量 h_v ^ {(k)}を次のように計算するモデルの集合と定める。

\begin{align} a _ v ^ {(k)} &= AGGREGATE ^ {(k)} (\left\{ h _ u ^ {(k - 1)} | u \in N(v) \right\}) \\ h _ v ^ {(k)} &= COMBINE ^ {(k)} ( h _ v ^ {(k - 1)}, a _ v {(k)}) \end{align}

 \left\{ h _ u ^ {(k - 1)} | u \in N(v) \right\}は頂点 vの近傍の特徴量のmultiset(重複を許容する集合)を示しており、このように近傍頂点に対する畳み込みをmultiset上の演算と同一視することができる。

GraphSAGE

\begin{align} AGGREGATE ^ {(k)} &= MAX \circ ReLU \circ W \\ COMBINE ^ {(k)} &= W \circ CONCAT \end{align}

で、GCN

\begin{align} AGGREGATE ^ {(k)} &= ReLU \circ W \circ MEAN \\ COMBINE ^ {(k)} &= I \end{align}

で定式化できる。*1

WL-test

特徴量を離散空間に限定し、multisetとみなした近傍頂点の特徴量集合に対して、固有のラベルを割り当てることを考える。これは近傍頂点を根つき木として展開した際に、それぞれの木に固有のラベルをつけて区別することを意味する。

f:id:ey_nosukeru:20190523150723p:plain
近傍を根つき木として展開する

k近傍まで展開した時の特徴量のmultisetを順次比較することで、2つのグラフが等しいかどうかを判定できる。

WL-testは少しでも違う近傍集合を別のものとして区別できるため、グラフ構造の区別能力においては理論上最強である。

Proposed Method

提案手法では

  • GNNで区別できるグラフはWL-testでも区別できること
  • AGGREGATEやCOMBINEが単射であれば、GNNの構造識別能力はWL-testと同等になること

を証明とともに示している。つまりGNNの識別能力についてはAGGREGATEやCOMBINEが単射であること(違うmultisetについてその特徴量も異なること)が重要であると主張している。

さらに、この枠組みのもとで、次の集約関数

\begin{align} h _ v ^ {(k)} = \phi ^ {(k)} \biggl((1 + \epsilon ^ {(k)}) h _ v ^ {(k - 1)} + \sum _ {u \in N(v)} h _ u ^ {(k - 1)} \biggr) \end{align}

( \phi ^ {(k)}: 2層以上のMLP)によって定義されるGNNであるGINを提案し、この識別能力がWL-testと同等であることを示している。

また MAX, MEAN, SUMという3つの集約関数に関して、具体的にそれぞれが区別できないようなグラフの構造を示すことでその能力の大小について議論している。

f:id:ey_nosukeru:20190523152816p:plain
meanやmaxが識別できないmultisetの例。1つ目はmeanもmaxも「どちらも青1」という出力になる。同様に真ん中は「緑1赤1」、右は「緑1赤1」となって区別できない。

f:id:ey_nosukeru:20190523152617p:plain
maxは各要素が存在しているかどうかしか判定できないが、meanはその割合まで判定できる。sumはそれぞれの具体的な数まで区別できる。

Results

f:id:ey_nosukeru:20190523153051p:plain
性能比較。GNN variantsはAGGREGATEやCOMBINEに名前がどんな関数を使ったかを表している。GIN-0は \epsilon ^ {(k)} = 0とした場合。

多くの場合で提案手法が良い性能を示している。 \epsilon = 0の方が全体的に良い性能を出しているのは学習の安定性の問題か。

感想

 \epsilonを学習させずに適当な値( \epsilon = 0.01とか)で学習した場合の結果や、本当にグラフ構造の識別構造だけがGNNの性能に影響するのか(汎化性能、似たものを同一視する能力は?)は疑問の残るところではありますが、GNN全体を論理立てて概観できた感がありとても参考になりました。

*1:GCNはAGGREGATEに h _ u ^ {(k - 1)}も含むのでちゃんと書こうとするともう少し違う形になります。

【論文読み】Semi-Supervised Classification with Graph Convolutional Networks

arxiv.org

Summary

ニューラルネットワークにグラフを入力として与えるための枠組みを与えた論文。グラフ上の畳み込み操作をスペクトル理論から定義してニューラルネットワークがうまく扱える形にした。

Introduction

グラフ上の分類問題として、グラフの構造 G = (V, E)とそれぞれの頂点 v \in Vに割り当てられた値 X_vが与えられ、一部のノードにのみ正解ラベル y_vが与えられている場合に残りのノードのラベルを予測する半教師あり分類問題を考える。接続行列を A (頂点 iと頂点 jが辺で接続されている時 A _ {ij} = 1 , それ以外は A _ {ij} = 0)として、従来手法では

\begin{align} L _ {reg} = \sum _ {i,j} | f(X _ i) - f(X _ j) | ^ 2 = f(X) ^ T \Delta f(X) \end{align}

( \Deltaは非正規化グラフラプラシアン \Delta = D - A) を正則化項として含めることで辺で繋がっている2頂点の値が大きく異ならないように学習させていたが、グラフ上で繋がっている点が必ずしも似たようなラベルを持つとは限らず、これではモデルの表現力を大きく制限してしまう。

Proposed Method

グラフの各ノードに対して通常のニューラルネットと同様に複数層の特徴量を割り当て、前の層の特徴量から次の層の特徴量を次式で計算する。

\begin{align} H ^ { (l + 1) } = \phi ( \tilde {D} ^ { - \frac {1} {2} } \tilde {A} \tilde {D} ^ { - \frac {1} {2} } H ^ { (l) } W ^ { (l) } ) \end{align}

ここで \tilde {A} = A + I,  \tilde {D} _ {ii} = \sum _ j \tilde {A} _ {ij} (対角行列),  H ^ {(l)} l層目の特徴量,  W ^ {(l)} H ^ {(l+1)} H ^ {(l)}間の重み行列,  \phiは活性関数を表す。こうすることでグラフの構造を考慮に入れつつニューラルネットの仕組みを取り入れることができる。重みの更新やノードのラベルの予測は通常のニューラルネットと同様バックプロパゲーションフォワードプロパゲーションにより行える。

Motivation

天下りに導入された計算式だがグラフスペクトル解析のフーリエ変換・フィルタリングの枠組みからある程度正当化できる。*1

グラフ上でのフーリエ変換・逆フーリエ変換は、正規化グラフラプラシアン L = I - D ^ {-\frac{1}{2}} A D ^ {-\frac{1}{2}}固有ベクトル u _ iを用いて次のようにかける。

\begin{align} F _ i &= \sum _ k f(k) u _ i ^ {\ast} (k) \\ f _ k &= \sum _ i F(i) u _ i (k) \end{align}

 f _ kが頂点空間での頂点 kの信号、 F _ iがスペクトル空間での固有値 \lambda _ iに対応する信号である。

グラフ上での畳み込みは、信号処理と同様スペクトル空間での積として定義され、フーリエ変換と逆フーリエ変換を組み合わせて次のようにかける。

\begin{align} g \ast x = U G(\Lambda) U ^ T x \end{align}

ここで G(\Lambda)は対角行列であり、固有値 \lambda _ iのスペクトル成分をどれだけ残すかを表すスペクトル領域でのフィルタである。

ただし、グラフラプラシアン固有値固有ベクトルを計算するのは計算量的に厳しいので、チェビシェフ多項式を使って近似する次の手法が提案されている。

\begin{align} G(\Lambda) & \approx \sum _ {k = 0} ^ K \theta _ k T _ k (\tilde {\Lambda}) \\ \tilde {\Lambda} &= \frac{2}{\lambda_{max}} \Lambda - I \end{align}

チェビシェフ多項式

\begin{align} T _ 0 (x) = 1, T _ 1 (x) = x, T _ k (x) = 2x T _ {k - 1} (x) - T _ {k - 2} (x) \end{align}

で定義される多項式列。 (U \Lambda U ^ T) ^ k = U \Lambda ^ k U ^ Tを用いると,  U T _ k (\tilde {\Lambda}) U ^ T = T _ k (U \tilde {\Lambda} U ^ T) = T _ k (\frac{2}{\lambda _ {max}} L - I)となり、 \tilde {L} = \frac{2}{\lambda_{max}} L - Iとおくと

\begin{align} g \ast x \approx \sum _ {k = 0} ^ K \theta _ k T _ k (\tilde {L}) x \end{align}

が得られ、これはラプラシアン K乗、つまり K次の近傍の頂点のみに依存する形になっている。

ニューラルネットでは層を積み重ねていくため K = 1で十分であり、さらにニューラルネットワークのスケール調整力に期待して \lambda_{max} = 2とすると*2

\begin{align} g \ast x \approx \theta_0 x + \theta_1 (L - I) x = \theta_0 x - \theta_1 D ^ {-\frac{1}{2}} A D ^ {-\frac{1}{2}} x \end{align}

パラメータ数を減らすため \theta = \theta_0 = – \theta_1とし、また数値的な安定性を向上し勾配消失を防ぐために

\begin{align} I + D ^ {-\frac{1}{2}} A D ^ {-\frac{1}{2}} \rightarrow \tilde {D} ^ {-\frac{1}{2}} \tilde {A} \tilde {D} ^ {-\frac{1}{2}} \end{align}

と置き換えることで(この辺の正当性や理屈はよくわからなかった)、特徴量マッピングの式

\begin{align} Z = \tilde {D} ^ { - \frac {1} {2} } \tilde {A} \tilde {D} ^ { - \frac {1} {2} } X \Theta \end{align}

を得ることができる(Renormalization trick)。計算量は \Theta \in R ^ {(C \times F)} ( Cは特徴マップのチャンネル数,  Fは特徴量の次元)として O(|E|FC)となる。

Example

2層のネットワークの場合、 \hat{A} = \tilde {D} ^ { - \frac {1} {2} } \tilde {A} \tilde {D} ^ { - \frac {1} {2} }をあらかじめ計算しておき、出力を

\begin{align} Z = softmax(\hat{A} ReLU(\hat{A} X W ^ {(0)}) W ^ {(1)}) \end{align}

などとして計算する。多クラス分類の場合誤差を

\begin{align} L = - \sum _ {l \in Y_L} \sum _ {f = 1} ^ F Y _ {lf} ln Z _ {lf} \end{align}

( Y_Lはラベルのある頂点の集合) として計算してバックプロパゲーションで勾配から重みを更新する。

Results

f:id:ey_nosukeru:20190523010506p:plain
どのデータセットでも大幅な精度上昇が確認できる

f:id:ey_nosukeru:20190523010717p:plain
Renormalization trickを行った方が精度がよくなっている

感想

本筋ではないですがグラフスペクトル理論の美しさに感動しました。グラフラプラシアンという単純な行列から、普通の信号処理と同様の変換でちゃんと意味のある数字が得られるというのはただただ驚くばかりです。一度ちゃんと勉強してみたい...

GNNを本腰入れて研究するつもりは今のところあまりないですが、何しろトレンドなのでこれからもほどほどにキャッチアップしていきたいと思います。

最後までお読み頂きありがとうございました!

*1:大雑把に眺める上ではグラフ上のラプラシアンとスペクトルについてグラフ信号処理のすゝめがわかりやすかったです。

*2:ちなみに \lambda _ {max} = 2はグラフが2部グラフであることを示すらしいので、あまり現実的な仮定ではありません。

GoConference2019Spring参加記

5/18(土)に東京で開催されたGoConferenceというイベントに、メルカリさんのスカラーシップで参加させて頂きました!

gocon.jp

このイベントは1年に2回東京で開催されており、メルカリさんをはじめとする有志企業によって運営されているそうです。400人を超える参加者を抱えながら参加費を無料に抑えて開催してくださっているというのはすごいことで、本当にスポンサーさんに感謝するばかりです。

tech.mercari.com

今回はメルカリさんのスカラーシッププログラムに申し込ませて頂き、交通費・宿泊費の補助だけでなく、イベント前日にはGoogle本社からのGoTeamの方々も交えたランチ、オフィスツアーや社内勉強会への参加まで提供頂きました。本当に盛りだくさんの内容で良い経験になりました。

Sessions

早速ですが、特に興味を持ったセッションについていくつかご紹介したいと思います。

Keynote

Keynote(オープニングトーク的なやつ)はGoTeamのKatie HockmanさんによるGoProxyに関する話題でした。go getした時に実はソースコード取ってくるのではなく、GoogleProxyサーバーを挟んで色々やってくれているというお話です。

Julieさんも一緒に来て頂いていたGoTeamの一人です。

Modules

go modで生成したgo.modとgo.sumをレポジトリにおいておくと、go get時に関連モジュールを(依存関係を解消する形で)勝手に取ってきてくれるというもの。例えば別の依存パッケージが同じパッケージの別バージョンを参照している場合でも、それらを全て満たすバージョンのものを取ってきてくれるようです。

Mirrors

参照先のソースコードは消えてたり変更される可能性があるので、Proxyサーバーを挟んで1つのパッケージのあるバージョンのソースコードが同一であることを保証しているという内容。具体的にgo getの挙動は次のようになるみたいです。

  1. 対象パッケージの存在するバージョン一覧を取得
  2. 最新のバージョンのソースコードとその情報を取得
  3. go.modを取得して依存先のライブラリを再帰的に取得
  4. 全体のコードをzipとして取得

Checksum Database

取得したソースコードの正当性(第三者によって改竄されていないこと)を保証するためにチェックサムを照合するが、その照合先のチェックサムを1つのデータベースで管理しているという内容です。

www.certificate-transparency.org

上の通りGoの照合システムではマークルツリーという二分木構造を導入しており、それぞれの親ノードのハッシュ値を子ノードのハッシュ値から計算するようにすることで、全体の整合性を保ちながら大量のチェックサムを管理します。GoProxyでは"Trust on One's Use"という考え方に基づき、データベースにレコードがない場合は最初のリクエスト時のハッシュ値を信用してデータベースに保存します。

エラー設計について/Designing Errors

docs.google.com

メルカリの@morikuniさんによるエラー処理の設計に関するトーク。ベストプラクティスだけでなく、そこに行き着くまでのエラー処理そのものの根本的な考え方にも触れられていてすごく参考になりました。

未知のエラーを既知とするための方法がエラーハンドリングである

エラーとは、アプリケーションの実行時に起こる想定外の事態です。そのような未知のエラーを明示的にハンドリングし、自分のアプリケーションの一部として取り込むことがエラーハンドリングだという考え方です。

複数の関係者がエラーにそれぞれ異なる情報を求めている

一口にエラーと言っても処理の階層・エラーを表示する対象によって必要な情報は異なり、それぞれの相手に必要十分な情報を提供できるようにエラーを管理しなければいけないということです。

僕自身もWebAPIをGoで書くにあたりエラー処理が一番の悩みの種で、なかなかこれといった方法論を確立できておらず、それが今回の参加の理由の一つでもありました。特に

「階層的なアーキテクチャを採用している場合レイヤーによってエラーの粒度が異なるが、レイヤーごとに別々のエラーを定義して順次変換していった方がいいのか、同じエラーをレイヤー間で共有してしまっていいのか」

ということについて悩んでいたのですが、今回の「同じエラーでも相手によって提供すべきエラーが異なる」という前提の元では、できるだけ大まかなエラーのみをアプリケーション全体で定義して、詳細やレイヤーごとのメッセージ・情報はレイヤーごとに付加していくのが最善なのかなと思いました。

紹介して頂いたfailure/xerrorsライブラリと合わせて近いうちにリファクタリングしてみたいと思います。

Design considerations for Container based Go application

speakerdeck.com

@hgsgtkさんによる、コンテナアプリケーションをGoで構築するときに考慮すべき事項をまとめたセッションです。こちらも具体的な実装に至るまでの思考過程が整理されており、

という流れが明確でわかりやすく、コンテナアプリケーションに限らず多くの場面で使える考え方であると思いました。

また技術面では

  • テストのしやすさも考慮した環境変数の扱い方
  • ロガーを抽象化してイベントストリームとして扱う
  • ヘルスチェック用のエンドポイントは、API自体だけでなく依存先のサービスのヘルスチェックも行う

などすぐに実践できるテクニックが盛り込まれていたのがとても参考になりました。

CPU, Memory and Go

speakerdeck.com

@sonatardさんによるトーク。GoのCPU・メモリ周りの話から始まり、Goの軽量さの理由や効率的なメモリ管理について紹介して頂きました。

  • appendはスライスのサイズが不足しているときに追加のタイミングでメモリを確保する。ヒープのメモリ確保はとても遅いので、あらかじめサイズがわかっている時は初期化時にまとめて確保する方が望ましい。メモリの空間局所性の観点からも、まとめて確保した方がキャッシュのヒット率が上がるので高速化が期待できる。
  • Escape Analysis: 関数内の変数はスタック領域に確保されるが、関数を抜けるときに必要な変数(他のスコープに渡したポインタなど)はヒープ領域に移される。
  • 変数はメモリアラインメント (64bit/8byte)の単位に沿って確保され、アラインメントを跨ぐような変数確保が発生する場合はパディングにより次のアラインメントに書き込まれる

などのGoの言語仕様を知りました。特に普段何気なく使っている構造体のポインタ渡しなど、メモリ管理など細かい内部処理がうまく隠蔽されており、ユーザーが細かい部分を気にせずとも処理系が自動的に最適化・高速化してくれる使い勝手の良さにGoの人気の一因を垣間見た気がしました。

感想

今回のGoConferenceでは自分の知らない様々な技術を知ることができました。

  • go modによる依存関係・チェックサム管理
  • Builder PatternとFunctional Optional Pattern
  • failure/xerrorsによるエラーハンドリング
  • pprofでのプロファイリング・パフォーマンス解析
  • APIの実装を簡単にするimpl/wire/gotests/fillstructといったパッケージ

同時に自分の知識のなさを痛感し、これから一層頑張って勉強していかないといけないという大きな刺激になりました。学んだことを書き残すだけでなく、今後自分のコードのリファクタリングを通して実践していければと思っています。

おまけ:前日・当日の様子

前日はまずメルカリの社員さん(@tenntennさん、@codehexさん)、Google@ymotongpooさん、さらにGoogleの本社からのGoTeamからの3人とご一緒させて頂きました。会話は全て英語で、Go2(Generics/Contract)やエラー処理のノウハウについて、色々質問させて頂きました。英語力不足で踏み込んだ話がなかなかできなかったのが残念なところです...ちなみにGoTeamで多く使われているエディタはVSCodeVimだそうです。

その後はオフィスツアーと社内勉強会に案内して頂きました。社内勉強会は何か1つのトピックについて話すというよりはそれぞれが持ち寄った話題を共有するという雰囲気でした。今回はスカラーシップ生の質問を中心に色々教えて頂きました。ちなみにメルカリのGoTeamで多く使われているエディタはGoLandとVim、ついで VSCodeやPlayground(嘘です)だそうです。

f:id:ey_nosukeru:20190521151355j:plain:w400
メルカリのオフィス玄関

当日はリクルートさんのオフィス(グラン東京サウスタワー)が会場でした。きれい・広い・眺めがいいと素晴らしい会場でした。

スポンサーさんによるフリースペースもあり、Gopherくん特製カップのコーヒも提供して頂けました。

f:id:ey_nosukeru:20190520113420j:plain:w450
各スポンサーさんのグッズが置かれたフリースペース

f:id:ey_nosukeru:20190520113804j:plain:w450
Gopherくん特製カップ

f:id:ey_nosukeru:20190521151746j:plain:w450
突如始まるGo Quiz

f:id:ey_nosukeru:20190521151904j:plain:w450
懇親会の様子。DeNAさんに提供して頂いたビールがとても飲みやすかったです

参加記は以上になります。最後までお読み頂きありがとうございました!

【論文読み】Weight Uncertainty in Neural Networks

arxiv.org

Summary

ニューラルネットワークの重みに確率分布を導入することで、過学習を防ぎ、Dropoutのように複数のモデルをアンサンブルしたような効果が得られる。Contextual Bandit(後述)のように確率的な意思決定が必要な場合にも適用できる。

Proposed Method

ネットワークの重み wに確率分布 P(w)を導入する。ニューラルネットは入力 x_iに対しその出力 y_iの分布を与えるモデル P(y_i | x_i, w)として解釈できる。

f:id:ey_nosukeru:20190514164911p:plain:w500
ネットワークの重みそれぞれに確率分布を導入する

通常のベイズ推定の枠組みと同様、データに対する尤度と事前確率を用いて、重みの事後確率を

\begin{align} log P(w|D) \simeq log P(w) + log P(D|w) = log P(w) + \sum_{i} log P(y_i | x_i, w) \end{align}

で計算する。*1

MAP(事後確率最大化)の枠組みではこの確率を最大化することを目指すが、これは通常のニューラルネットにおいて重みについての正則化(ガウス事前分布→L2正規化、ラプラス事前分布→L1正則化)を導入した場合とみなすことができる。

推論の際には事後確率 P(w|D)について、それぞれのwに対してその取りうる確率で重みづけした期待値である

\begin{align} P(y|x) = E_{P(w|D)} [P(y|x, w)] = \int P(w|D) P(y|x, w)] dw \end{align}

で計算する。これは重みが異なる無数のニューラルネットワークをアンサンブルしたものと解釈できる。

ただもちろんこの積分値は厳密に計算できない(intractable)ので、変分ベイズ法の枠組みで事後確率をパラメータ化された別の単純な確率分布により近似することで対処する。この近似誤差をKLダイバージェンスによって測る。 \begin{align} KL[q(w|\theta) || P(w|D)] &= \int q(w|\theta) log \frac{q(w|\theta)}{P(w|D)} \\ &= \int q(w|\theta) log \frac{q(w|\theta)}{P(D|w)P(w)} \\ &= KL[q(w|\theta) || P(w)] - E_{q(w|\theta)} [log P(D|w)] \\ &= F(D, \theta) \end{align}

これは Dに対するパラメータ \thetaのコスト関数とみなせる。これを最小化する \thetaを求めることが目的。さらにこのコストはモデルの複雑さに対するペナルティである1つ目の項、データへの適合度を表している2つ目の項に分解して解釈できる。

通常であれば事前分布と q(w, \theta)に共役な分布を用いてKLを解析的に計算することが多いが、今回はKL項を解析的に計算できない場合についても勾配法を用いた更新を行う手法について考える。

 F(D, \theta) \thetaに関する勾配法で最適化することを考えたい。ここでパラメータに依存しないノイズ \epsilonが存在し、 wがある決定的な関数 tを用いて w = t(\theta, \epsilon)と書ける時(つまり、確率分布が簡単なノイズ項とその変形として分解できる時)、次が成り立つ。

\begin{align} \frac{\partial}{\partial \theta} E _ {q(w|\theta)} [f(w, \theta)] = E _ {q(\epsilon)} [\frac{\partial f(w, \theta)}{\partial w} \frac{\partial w}{\partial \theta} + \frac{\partial f(w, \theta)}{\partial \theta}] \end{align}

これはGaussian reparameterisation trickと呼ばれる手法の拡張であり、確率モデルが単純な場合の微分連鎖律的なものになっている。

よって

\begin{align} \frac{\partial}{\partial \theta} F(D, \theta) &= \frac{\partial}{\partial \theta} \int q(w|\theta) \log \frac{q(w|\theta)}{P(D|w)P(w)} \\ &= \frac{\partial}{\partial \theta} E _ {q(w|\theta)} [\log q(w|\theta) - \log P(D|w) - \log P(w)] \\ &= E _ {q(\epsilon)} [\frac{\partial f(w, \theta)}{\partial w} \frac{\partial w}{\partial \theta} + \frac{\partial f(w, \theta)}{\partial \theta}] \end{align}

ここで f(w, \theta) = \log q(w|\theta) - \log P(D|w) P(w)とおいた。この形まで持ってくれば、 \epsilon q(\epsilon)に関してサンプリングして中身の微分を計算することで期待値を近似的に計算することができる。この期待値を勾配として \thetaを更新していくことでモデルの最適化を行うことができる。

 wは無数の辺の重み*2であり、そのそれぞれが互いに独立なガウス分布に従うと仮定する。すると \sigmaを対角成分としてもつ対角行列 \Sigmaを用いて w \sim N(\mu, \Sigma)とかける(つまり q(w|\theta) = N(\mu, \Sigma))。

 \sigma >= 0の条件を入れるため \sigma = \log (1 + exp(\rho))とおき直すと、 \epsilon \sim N(0, I)を用いて w = t(\theta, \epsilon) = \mu + \log (1 + exp(\rho)) \circ \epsilonとかける。パラメータ \theta = (\mu, \rho)であり、 \rhoは要素ごとの積を表す。*3

この条件のもとで \begin{align} \frac{\partial w}{\partial \mu} &= 1 \\ \frac{\partial w}{\partial \rho} &= \frac{\epsilon}{1 + exp(-\rho)} \end{align}

を使うことで、ベイズニューラルネットの全体のアルゴリズムを次のように書ける。

  1.  \epsilon \sim N(0, I)に基づいて \epsilonをサンプリングする。
  2.  w = \mu + \log (1 + exp(\rho)) \circ \epsilonを計算。
  3.  f(w, \theta) = \log q(w|\theta) - \log P(w) P(D|w)を計算。
  4.  \muに対する勾配を次の式で計算する: \begin{align} \Delta_{\mu} = \frac{\partial f(w, \theta)}{\partial w} + \frac{\partial f(w, \theta)}{\partial \mu} \end{align}
  5.  \rhoに対する勾配を次の式で計算する。 \begin{align} \Delta_{\rho} = \frac{\partial f(w, \theta)}{\partial w} \frac{\epsilon}{1 + exp(-\rho)} + \frac{\partial f(w, \theta)}{\partial \rho} \end{align}
  6. 勾配を用いてパラメータ \mu,  \rhoを更新する。

 P(D|w) = \sum_{i} log P(y_i | x_i, w)は通常のバックプロパゲーションで得られる勾配であり教師あり学習の枠組みで計算できる。

この手法では P(w) q(w|\theta)を共役にして KL[q(w|\theta) || P(w)]を解析的に計算する必要がないため、 P(w)を自由に選ぶことができる。例えばこの論文では混合ガウス分布を使っている。

訓練データをミニバッチに分割する場合は全体に対して和をとって F(D, \theta)に一致させる必要があるので、KL項に適切に重みづけする。例えば

\begin{align} F_i^{\pi}(D_i, \theta) = \pi_i KL[q(w|\theta) || P(w)] - E_{q(w|\theta)} [\log P(D_i|w)] \\ \end{align} \begin{align} \pi _ i = \frac{2 ^ {M - i}}{2 ^ M - 1} \end{align}

とすると、学習初期は複雑コストの影響を大きく受け、後半はデータ適合性が重視されるように学習がうまく進むことがわかった。

Experiments

MNISTの分類問題と単純な曲線に対する回帰、Contextual Banditの3つについて実験を行なっている。またネットワークの圧縮についても議論している。

MNIST

f:id:ey_nosukeru:20190514165239p:plain
左: 分類性能 右上: 学習経過 右下: ネットワークの重みの分布

性能としてはDropoutと同等の性能。重みの分布を見るとより多くのエッジが協調して予測を行なっていることがわかる(取り除かれると結果に大きく響くクリティカルなエッジが少ない)。

またSingle-to-Noise( \log \frac{|\mu_i|}{\sigma_i})に注目して辺の重みの分布を調べており、その値が小さいものから順に辺を除去していくことで、90%付近までほぼ精度を落とさずにネットワークを圧縮できることを示している。

f:id:ey_nosukeru:20190514165707p:plain:w400
Signal-to-Noiseについてエッジの分布を示したグラフ

f:id:ey_nosukeru:20190514165947p:plain:w400
辺を除去した後の予測精度の変化

Regression

f:id:ey_nosukeru:20190514170245p:plain:w500
曲線に対する回帰。左: 提案手法 右: 従来手法

従来のモデルがデータのない領域では一つの曲線にフィッティングしてしまっているのに対し、提案手法ではデータがないことによる不確実さを評価しちゃんと分散が大きくなっている。

Contextual Bandit

各ステップで食用or毒のキノコの特徴が与えられ、食用のキノコを食べる選択をした場合には+5、毒キノコを食べる選択をした場合には50%で+5か50%で-35の報酬が与えられる(食べない場合は0)という強化学習(?)の問題設定で検証を行なったもの。

f:id:ey_nosukeru:20190514170848p:plain:w500
キノコ問題におけるパフォーマンス。Regretは最適解からの乖離を表し、少ないほど良い。Greedyは強化学習における \epsilon-Greedy戦略を表している

予測に確率モデルを導入することで、問題の不確実さを評価しより良い性能が得られていることがわかる。

感想

比較的シンプルな形でニューラルネットに確率を導入できている点が興味深いと思いました。論文でもフレームワークの力を借りれば今すぐにでも実装が可能という点が強調されています。

無限のモデルをアンサンブルできるとは言いつつ結局サンプリングで1-5個しかとってこれないのは正直どうなん?とも思いましたが、Regressionで示されているようにデータ数の少ない場合や強化学習において確率的な環境を扱う場合に有用なのではないかと感じました。

やっぱりベイズ推論系の話はちゃんと数式が出てきて安心しますね笑。ガウス過程+ニューラルネットワークとかもそろそろ出てきているんでしょうか?また調べてみたいと思います。

最後までお読み頂きありがとうございました。

*1:ベイズの定理より。 P(D) wに関わらず一定なので、 wの最適化を考える上では無視できる。

*2:バイアスを含めて考えることも可能

*3:ベクトルでまとめて記述するためやや難しい記法になっていますが、それぞれが単純なガウス分布に従っているだけです