Transformer完全に理解した

f:id:ey_nosukeru:20190622041311p:plain:w300
BLEACH』170話より

出オチです。nosukeruです。

今回はこの論文に関する解説記事になります。

arxiv.org

読んだのは結構前なのですがわからない部分が多くて放置しており、紆余曲折の末理解できた感動や躓いた部分を共有したく思い記事にまとめました。

Transformerとは

f:id:ey_nosukeru:20190622044350p:plain:w500
イメージ図

2017年12月頃の上の論文でまたもやGoogleチームにより発表され世間の注目を集めた、時系列データ処理に革新をもたらしたモデルです。

  • それまでの時系列データ処理における主役であったRNNやCNNを全く使わないアーキテクチャ
  • Attention機構、特にSelf-AttentionやMulti-Head-Attentionのみを使用
  • 並列処理を可能にしたことにより高速な学習を実現、性能も当時の最高記録を更新

などの目覚ましい特徴により多くの注目を集め、時系列データ・自然言語処理の歴史を塗り替えてしまいました。(引用数2000超え!) これ以降はこれらの分野においてはほとんどのモデルでAttentionが標準的に用いられるようになっていきます。

f:id:ey_nosukeru:20190622044119p:plain:w500
モデルアーキテクチャ

Transformerについては次の@Ryobotさんの記事に大変わかりやすくまとまっており、大いに参考にさせて頂きました。詳細についてはこちらを参照してもらえればよいかと思います。

deeplearning.hatenablog.com

この記事では、Transformerのパーツの中でも特に重要な役割を果たすAttention機構についてと、上の記事ではやや理解しづらかった、学習の仕組みや全体の処理の流れに焦点を当てて説明していきます。

Attentionとは

Attention自体はこの論文以前にあった技術で、人間が画像や文字列を見る際にその一部に注目しているという「注意」の機構を参考にし、ニューラルネットに対してもデータの一部により焦点を当てて特徴を抽出する仕組みを取り入れることで性能の向上を図ったもの、ということができます。

この論文を読む以前はそのぐらいの漠然としたイメージだったのですが、この論文では注意を「メモリデータに対するクエリによる検索」と説明していてしっくりきました。

具体的にはクエリ群 Q \in \mathbb{R}^{D \times N}、辞書のキー K \in \mathbb{R}^{D \times M}、値 V \in \mathbb{R}^{S \times M}に対し、注意によって得られる特徴ベクトルを softmax(QK^{T})Vによって計算します。これは q \in Q k \in Kのそれぞれの内積 q \cdot kを計算し、その値で重み付けして対応する v \in Vを足し合わせることに相当します。内積が大きい \iff特徴ベクトルが似ているということなので、 qに類似したレコードを検索して取ってくる作業(をソフトにしたもの)とみなすことができます。

f:id:ey_nosukeru:20190622045649p:plain
上記事からの引用。辞書 (K, V)に対し複数のクエリ Qで検索を行うイメージ

Attention自体のイメージが掴めれば他はそこまで難しくありません。Self-AttentionはAttentionの Q, K, Vに全て同じデータを入力したもの、Multi-Head Attentionは線形変換をかませてから入力し、この線形変換の組を複数用意して別々にAttentionを計算・組み合わせる機構のことです。

f:id:ey_nosukeru:20190622125223p:plain
Multi-Head Attentionの構造

図中のScaled-Dot Attentionは通常のAttentionにスケールの調整を加えたものです( \text{softmax}(\frac{QK^{T}}{\sqrt{d}})V)。理由としては dが大きくなった場合にsoftmaxの値が一部に集中せず全体的にマイルドな値にするためとのことです。

僕が論文を読んでいて気になった(分からなかった)ポイントとそれに対する答えをいくつか紹介します。

Self-Attentionは同じ位置に対する注意(softmax部分の対角成分)が一番大きくなる?

普通のAttention(かつ特徴ベクトルの大きさが大体等しい場合)では言えそうですが今回の場合は入力の前に Q, K, Vそれぞれに別の線形変換を施しているので、同じ位置の内積が必ず大きくなるとは限りません。線形変換によって特徴量のどこに注目して処理するかを決めてるイメージかなと思います。

Attentionの入力、出力はそれぞれ何?

上の説明ではAttentionをメモリの検索という風に抽象化したので実際に何が入力され、何が出力されるのかイメージするのが難しいかもしれません。

時系列データ、特に翻訳の文脈では K, Vのどちらも参照先の時系列データ X = \{x_1, x_2, ..., x_M \}を入れることが多く、 Qに対象のデータ Y = \{y_1, y_2, ..., y_N \}が入力された場合には(それぞれの特徴ベクトルは d次元ベクトルで共通)、出力として y_iのそれぞれに対するAttentionの計算値 A = \{a_1, a_2, ..., a_N \} a_iはd次元)が得られます。

例えば翻訳では X = \{"これ", "は", "ペン", "です", "。"\}が翻訳前の文章、 Y = \{"This", "is", "a", "pen", "."\}翻訳語の文章で、Xのそれぞれの要素に対して単語の類似度("is"ならば"は"、"pen"ならば"これ"や"ペン")に基づいてその特徴量を足し合わせたものがAttentionである Aということができます。

f:id:ey_nosukeru:20190622133604p:plain
(Self-)Attentionのイメージ図。関連する単語同士の重みが大きくなっている

上の図のように特にSelf-Attentionでは文章内の関連する単語を結びつけるような働きが期待できます。

で、Attentionって結局何者?

僕の場合特徴量同士の類似度を測る、ってところまでは理解できたのですがその後の特徴量を足し合わせる、っていうのが直感的にいまいちピンと来ずモヤモヤしてました。

f:id:ey_nosukeru:20190622134801p:plain
畳み込みとAttentionの比較(http://deeplearning.hatenablog.com/entry/transformerから引用)

これについても@Ryobotさんの記事がとても参考になり、Attentionを畳み込みと比較するのがわかりやすいです。つまり、近傍の要素に重みをつけて特徴量を足し合わせている畳み込みに対し、Attentionでは重みパラメータの代わりに要素同士の類似度によって重みをつけて全ての要素を足し合わせています。つまり、Attentionは「パラメータなしで類似度による重みづけを行う全範囲版の畳み込み」とみなすことができそうです。範囲が限定されないために離れた位置の依存関係を扱うことができること、「メモリの参照」による「キーと値の間の非自明な変換」の能力が、Attentionの表現力の高さの一因となっていると言えると思います。

学習・推論の流れ

学習・推論時にデータをどのように入力し、どのような出力が得られてどのように学習していくのかというのが今回の論文の中で一番理解が難しいところではないかと個人的には思っています。アーキテクチャの図を再掲し、それぞれのレイヤーでのデータがどうなっているのかとともに説明していきます。以下では、日本語を英語に翻訳するタスクを例に考えます。

f:id:ey_nosukeru:20190622044119p:plain
モデルアーキテクチャ(再掲)

Encoder部

まず、左のEncoder部により日本語の文章を何らかの特徴ベクトルに変換していきます。文章をデータとして扱わせるため、まず各単語を何らかの特徴ベクトルに変換します(埋め込み)*1

また例によって効率的な計算のためにバッチ処理を行うので、バッチサイズを b、単語数を n、埋め込みの次元を dとすると入力ベクトルの次元は (b, n, d)になります。これに位置の情報が入るような埋め込み(Positional Embedding)を足し合わせてEncoderへの入力とします。

f:id:ey_nosukeru:20190622160946p:plain
Encoder部

Encoder部ではまず入力に対してSelf-Attentionを計算します。バッチのデータ1つごと、つまり (n, d)のデータに対してSelf-Attentionを計算し、その出力の次元も (n, d)となります。よってSelf-Attention後も入力サイズは変わらず (b, n, d)となります。Attention結果と元のデータをresidual的に足し合わせ、適切に正規化して中間層のデータ (b, n, d)が得られます。

次にFeedForward層によってそれぞれの特徴ベクトルを変換します。この処理は時系列に関する計算は行わず、単純にそれぞれの特徴ベクトルを全結合層により変換します。隠れ層の次元を d_h、出力層の次元を d_oとするとデータの次元は (b, n, d) \rightarrow (b, n, d_h) \rightarrow (b, n, d_o)となります。ただしこの論文では単純に d_o = dとしているため、FeedForward層によって次元は変化せず結局 (b, n, d)となります。

この処理がEncoderのブロック分繰り返されるのですが、最終的な次元は変わらず (b, n, d)になります。これでEncoderの出力が得られました。

Decoder部

ここが最大のポイントなのですが、RNNでは現在の位置の単語を入れて次の単語についての予測結果を得るのに対し、Transformerでは1つの文章(つまり m個の単語)を受け取ってそれに対する m個の予測結果を同時に計算します

 Y = \{y_1 = "This", y_2 = "is", y_3 = "a", y_4 = "pen", y_5 = "."\}という入力を例にとって考えると、Decoderは Yを受け取った結果として \{P(y_1), P(y_2|y_1 = "This"), P(y_3|y_1 = "This", y_2 = "is"), ... \}を出力するということです。つまり i番目の出力は 1から (i - 1)番目の単語を受け取った時の i番目の単語に対する予測結果ということになります。

ただし普通にやると入力でそれぞれの予測の答え y_iが与えられてしまっていてまともな問題にならないので、 y_iを予測する際に y_1から y_{i - 1}までの結果のみを使うようにモデルに制約をかける必要があります。そのためにTransformerではAttentionの部分で未来を参照するような部分の重み(クエリ y_1に対するキー y_3の類似度)を強制的に0にすることでそのような情報の漏洩を防いでいます。

f:id:ey_nosukeru:20190622161015p:plain
Decoder部

さて、Encoderと同様データの流れを見ていくと、まずは英語文のバッチ化された単語の埋め込み (b, m, d)がDecoderに対する入力として与えられます。次に(マスクされた)Self-Attentionを通りますが、データの次元は変わらず (b, m, d)になります。この処理により予測対象の単語より前の単語列の特徴量が計算されます。

次に翻訳前の文章の特徴量を活用するため、Encoderの出力を受け取ります。EncoderからのデータをKey/Valueとして、現在の特徴量をQueryとしてAttentionをかけます。この部分のAttentionの働きは、翻訳後の文章のそれぞれの位置に対して、「その位置の前の単語の特徴量を活用しつつ、翻訳前の文章中で似た特徴量を持つ単語の位置を探してくる」ようなイメージなのかなと思っています。

Attentionの出力データの次元は変わらず (b, m, d)になります。AttentionにおいてはKeyとQueryの長さは一致する必要がないことに注意です(内積を計算するため要素の次元 dは一致する必要あり)。

FeedForward部、ブロックの繰り返しについてはEncoder部と同様です。Decoder部ではこの後単語次元 lへの線形変換がかけられ、softmaxによって各位置の単語が l個の候補単語のどれであるかの確率 P(y_i|y_1, y_2, ... y_{i - 1}) (B, m, l)次元の出力ベクトルとして得られることになります。

学習・推論の流れ

学習の際には、Encoder部への入力として日本語の文章 D_X = \{X_1, X_2, ..., X_b\}を、Decoder部への入力として D_Xに対応する答えである英語の文章 D_Y = \{Y_1, Y_2, ..., Y_b\}を与え、出力としてそれぞれの単語の生起確率 P_Y = \{p_1, p_2, ..., p_b\}が得られます。ここで p _ {ji} = P_j(y_i|y_1, ..., y _ {i - 1})がj個目のデータのi番目の位置での単語の予測結果になっています。この出力に対して正解とのCross Entropyを誤差として学習を行います。

翻訳後の文章を予測させたいのにその正解の文章をモデルに入力として与えているのがかなり奇妙に見えますが、このようにデータを一気に与えて並列計算を可能にし、GPUの性能をフルに使うことができるようにしたのがTransformerの大きな成果の一つです。

推論時には正解の文章が分からないため、Decoder部に D_Yを直接与えることができません。そのため、まずEncoder部に D_Xを入力してその出力を得て、その値を用いて P(y_i|y_1, ..., y_{i - 1})を計算していきます。まず P(y_1)を計算し*2、その確率を最大化するような単語として y_1を決定します。そして得られた y_1を結合してもう一度Decoder部に入力し、 P(y_2|y_1)から y_2を決定します。これを繰り返して文章の長さを1つずつ伸ばしていき、出力として文章の終わりを表す特殊記号が現れた時点で文章の生成を終了します(入力のサイズ: (B, 1, l) \rightarrow (B, 2, l) \rightarrow ...)。

なお、文章を伸ばしていきながら末端の単語を計算する時、副産物として P(y_1) P(y_2|y_1)は毎回得られますがこれらを計算に使うことはありません*3。この無駄な結果を毎回計算するという部分も直感に反し、Transformerの挙動を理解する上で難しかったポイントでした。

最後に

今回はTransformerを理解する上で自分が躓いたポイントを元に解説させて頂きました。少しでも理解の一助となれば幸いです。

間違っている部分などあればご指摘頂ければありがたいです。最後までお読み頂きありがとうございました!

*1:なおここの埋め込みもパラメータ化して学習するようです。

*2:この際の入力には文章の始まりを表す特殊単語を与えます。

*3:ちなみにこれらの依存する変数の内容は変わらないので、これらは毎回同じ値になります。