英語帝国を打倒しよう

言語の壁に計算機で挑もう!

入力と出力から辿るtransformer(計算量等)

はじめに

transformerの仕組みについてはほとんど勉強したことが無く、どうも理解が怪しかったので改めて勉強し直すことにしました。特に気になっていたのは、並列化と推論時の計算量です。それぞれ以下のような話をしていた時に顕在化しました。

  • 友人Aとの会話
    • 自然言語処理が専門ではない友人にモデルの歴史を雑談程度に話していて、その時に、「LSTMとかの逐次的な系列変換モデルからtransformerになって、並列処理ができるようになったから凄い事起きたんだよ~」と言ったら、「どうやって並列化してるの?」と聞かれてごまかした。
  • 先輩Bさんとの会話
    • Bさんとの会話の中で「transformerの推論はO(n2)で~~(うんぬんかんぬん)」という話が出てきた。その時にあれ?そういえばなんでn3ではなくn2なんだろう?と思った。「なんか上手くやるとn2になる」みたいな話があった気がするのだが、完全に忘却していた。

これまでの自分の勉強を振り返ると、Attention機構の数式の中身などに気を取られて、それぞれの入力と出力が何なのかや、全体として何をしているかの理解が曖昧になっていた気がします。そこでこの記事では、入力から出力まで流れる行列の「サイズ」に注目して計算時間と外観を追います。

結論

  • 並列処理について-> # 入力から辿る「transformerモデルが解く問題」へ
  • 推論時の高速化について→ # 推論時 へ

transformer(訓練時)

transformer原著の図の矢印を辿りながら、行列のサイズと計算量を見ていきます。とりあえず、バッチサイズとヘッドについては考えないことにします。

まず、原著に登場するパラメータで、行列のサイズに関係するものを並べます。

  • d(d_model): embedding layerとモデルの中のsub-layerの出力サイズ。論文では512。この記事ではdと表記
  • d_ff: feed-forward層の次元数。論文では2048

また、入力トークン数をnとして全部の流れを追ったのがこちらの図です。以下では、それぞれの部分について細かく見ていきます。

Transformer(Vaswani, Ashish, et al. "Attention is all you need." Advances in neural information processing systems 30 (2017).)

encoder

入力としてトークン列(トークン数,1)が入ってきた後、モデルの中に入って来てからはずっと(トークン数,d_model)の行列で流れていきそのまま出ていきます。

encoder全体

入力部

初めに入力のトークン列X((トークン数,1)の整数列)がembeddingされ(トークン数,d_model)のHになり、これにpositional-encodingが加えられています。

  • 入力(inputs)は単語を表す整数列だと思うことにする。整数列は次の様に単語を整数に直したものである。(「I want to go to college.」-> 「5, 10, 23, 1, 23, 71」)
  • 計算量は O(n) (nはトークン数)

入力部

Multi_Head Attention

入力と3つのW行列の積をとり3つの行列を作成。注意機構を計算しています。

  • (トークン数,d)の行列に(d,d)の行列をかけ、K,V,Q(サイズは(トークン数,d))を作成
  • Attentionの計算 :  softmax(\frac{QK^{T}}{\sqrt{D}})V
  • K,V,Qの作成(行列積)が O(d^{2} \times n) Q×K^{T} softmax(\frac{QK^{T}}{\sqrt{D}})のVとの行列積等が O(d * n^{2})
  • normalizationは O(n*d)

Multi-Head Attention

次元数方向(d,d_model方向)をHead数(h)個ずつに分割し計算します。一回ずつの計算は[tex: n2*(d/h)]となり、これをh回行うので理論上の計算量は変わりません。オーバーヘッドはそこそこありそうですが...

  • softmaxの関数の仕様のため多様な情報を引き出すためにこうしているという話がある(自然言語処理の基礎、p151)

Head

FeedForward

一般的な線形ネットワークです。注意機構と違い、各トークンが独立して処理されています。

  • 1層のネットワークで式は Relu(H'' × W_1) × W_2
  • 計算量は H'' × W1と、 (Relu(H'' × W_1)) × W_2共に O(n * d * d_{ff})
  • この層が具体的にどのような役割を担っているのか、私にはよく分かっていない。ここで説明する代わりにリンクを置いておく。

Feed_forward

encoderの計算量まとめ

計算量的には、 O(d^{2} * n + d * n^{2} + n * d * d_{ff}) になります。d,d_ffは数百~数千位、nは用途次第ですが一文の翻訳とかなら長くても数百程度になります。ハードウェア的な制約を無視すれば理論上どんな長さでも入力とすることができます。

decoder(訓練時)

入力から最後のlinearに出るまで行列サイズはencoder同様(n,d)です。最後のlinearの計算で(n,d)から(n,vocabsize)に変わります。

  • 最後のlinearでサイズをd(d_model)から単語の種類数(vocabsize)に変える
    • 例えば、5000種類の単語で出力を構成するとしたら、1tokenあたり5000個の数値を出し、softmaxをかけてそれぞれの数値がの5000種類それぞれの単語の確率だと考える
    • 計算量は、 O(n * d * vocabsize)
      decoder

Masked Multi-Head Attention

基本的な計算はMulti-Head Attentionと殆ど変わらないのですが、文全体を同時に計算するために QK^{T}積の一部をマスクするという仕組みを持っています。

Masked Multi-Head Attention

transformerの並列化の仕組み

transformerは先の単語を予測するタスクなのに文全体でモデルに一回しか情報が流れないという構造を持っています。(# 入力から辿る「transformerモデルが解く問題」で後述)

訓練時に Q*K^{T}(=A)の要素の一部を消すことで先の単語の情報が入らないようにしています

Multi-Head Attention

入力の K',V' (サイズは(n,d))はencoderから来て、それらと \bar{H''}から来た QでAttentionを計算しています。

  • エンコーダー側の入力とデコーダー側の入力のサイズが違う場合はpaddingなどすればよいのであまり問題ではない
  • 計算量はencoderと同じ

decoderのMulti-Head Attention

Feedworward

encoderと同じなので省略

計算量についてのまとめ

headを考えないと O(n * d * vocabsize + d^{2} * n + d * n^{2} + n * d * d_{ff})となります。multiheadになると、 softmax(\frac{QK^{T}}{\sqrt{D}})部分の計算が変わりますが、計算量の理論値は変わらないでしょう。

入力から辿る「transformerモデルが解く問題」

transformerは『一本の入力文( x:「私は猫を飼っている」)と途中までの出力文「I have a」( y_{0-2})から、「cat」を予測する』という問題(i=2)を、すべてのiについて同時に並列に解いています。数式に書き表すと以下の様になります。

モデルに対する入力は『x:「私は猫を飼っている」』ではなく、『 x:「私は猫を飼っている」+  y_{0-2}「I have a」』となります。

  • なんとなく、「入力を入れると出力が出る」みたいな感覚があったのですがまずかったです。transformerも本質的には、「次の単語を予測する」だけの機能を持っていて、それを並列化している、と思うと腑に落ちました。
  • この並列化はMasked Multi-Head Attentionで上手い事やって、文章全体を一回で扱えるdecoder機構を作った、と言えると思います。

推論時

入力時は上手く並列化して文全体を一回モデルに流すだけで計算出来ていますが、推論時はそうはいきません。decoderモデルを出力単語数分だけ回すことになり、 O(d^{2} * n + d * n^{2} + n * d * d_ff + n * d * vocabsize)がn回かかり、nに注目するとn3の計算量がかかります。しかしながら、実はここでもうまいことやってdecoderモデルの計算がO(n2)より軽くできます

  • こちらの記事(Lei Mao Transformer Autoregressive Inference Optimization)に説明されていることを解説する。
    • これがPytorchなどの標準実装になっているのかはよく分からない(todo)
  • n2かかる部分というのはそもそも、Attention機構の QK^{T}の行列積と、これから出てきたものにVをかける部分なので、[tex: softmax(Q{n+1}K{n+1}^{T})V_{n+1}について考える。
  • 今求めたいのは、この計算結果の中で一番下の行だけであることを考えると、実は計算しなくてよい部分が多く、O(n)に落ちる

おわりに

最初の疑問に今自分が答えるとすれば以下のようになる。

  • 友人Aとの会話
    • transformerは従来のモデルと比べてモデルの能力自体は失わずに文全体を同時に扱う機構を持っているため高速に計算できる
  • 先輩Bとの会話
    • 推論時は実は計算しなくてよい部分が多く、decoderモデル一回当たりの計算がO(n)に落ちていて、文全体でO(n2)になっている

あんまり自信はないんで間違いがあったら教えてください。

空間計算量

系列変換モデルとattention機構

todo

GPT

todo

BERT

todo