はじめに
transformerの仕組みについてはほとんど勉強したことが無く、どうも理解が怪しかったので改めて勉強し直すことにしました。特に気になっていたのは、並列化と推論時の計算量です。それぞれ以下のような話をしていた時に顕在化しました。
- 友人Aとの会話
- 自然言語処理が専門ではない友人にモデルの歴史を雑談程度に話していて、その時に、「LSTMとかの逐次的な系列変換モデルからtransformerになって、並列処理ができるようになったから凄い事起きたんだよ~」と言ったら、「どうやって並列化してるの?」と聞かれてごまかした。
- 先輩Bさんとの会話
- Bさんとの会話の中で「transformerの推論はO(n2)で~~(うんぬんかんぬん)」という話が出てきた。その時にあれ?そういえばなんでn3ではなくn2なんだろう?と思った。「なんか上手くやるとn2になる」みたいな話があった気がするのだが、完全に忘却していた。
これまでの自分の勉強を振り返ると、Attention機構の数式の中身などに気を取られて、それぞれの入力と出力が何なのかや、全体として何をしているかの理解が曖昧になっていた気がします。そこでこの記事では、入力から出力まで流れる行列の「サイズ」に注目して計算時間と外観を追います。
結論
- 並列処理について-> # 入力から辿る「transformerモデルが解く問題」へ
- 推論時の高速化について→ # 推論時 へ
transformer原著の図の矢印を辿りながら、行列のサイズと計算量を見ていきます。とりあえず、バッチサイズとヘッドについては考えないことにします。
まず、原著に登場するパラメータで、行列のサイズに関係するものを並べます。
- d(d_model): embedding layerとモデルの中のsub-layerの出力サイズ。論文では512。この記事ではdと表記
- d_ff: feed-forward層の次元数。論文では2048
また、入力トークン数をnとして全部の流れを追ったのがこちらの図です。以下では、それぞれの部分について細かく見ていきます。
encoder
入力としてトークン列(トークン数,1)が入ってきた後、モデルの中に入って来てからはずっと(トークン数,d_model)の行列で流れていきそのまま出ていきます。
入力部
初めに入力のトークン列X((トークン数,1)の整数列)がembeddingされ(トークン数,d_model)のHになり、これにpositional-encodingが加えられています。
- 入力(inputs)は単語を表す整数列だと思うことにする。整数列は次の様に単語を整数に直したものである。(「I want to go to college.」-> 「5, 10, 23, 1, 23, 71」)
- 計算量は (nはトークン数)
Multi_Head Attention
入力と3つのW行列の積をとり3つの行列を作成。注意機構を計算しています。
- (トークン数,d)の行列に(d,d)の行列をかけ、K,V,Q(サイズは(トークン数,d))を作成
- Attentionの計算 :
- K,V,Qの作成(行列積)が、やのVとの行列積等が
- normalizationは
Head
次元数方向(d,d_model方向)をHead数(h)個ずつに分割し計算します。一回ずつの計算は[tex: n2*(d/h)]となり、これをh回行うので理論上の計算量は変わりません。オーバーヘッドはそこそこありそうですが...
- softmaxの関数の仕様のため多様な情報を引き出すためにこうしているという話がある(自然言語処理の基礎、p151)
FeedForward
一般的な線形ネットワークです。注意機構と違い、各トークンが独立して処理されています。
- 1層のネットワークで式は
- 計算量はと、共に
- この層が具体的にどのような役割を担っているのか、私にはよく分かっていない。ここで説明する代わりにリンクを置いておく。
encoderの計算量まとめ
計算量的には、 になります。d,d_ffは数百~数千位、nは用途次第ですが一文の翻訳とかなら長くても数百程度になります。ハードウェア的な制約を無視すれば理論上どんな長さでも入力とすることができます。
decoder(訓練時)
入力から最後のlinearに出るまで行列サイズはencoder同様(n,d)です。最後のlinearの計算で(n,d)から(n,vocabsize)に変わります。
- 最後のlinearでサイズをd(d_model)から単語の種類数(vocabsize)に変える
- 例えば、5000種類の単語で出力を構成するとしたら、1tokenあたり5000個の数値を出し、softmaxをかけてそれぞれの数値がの5000種類それぞれの単語の確率だと考える
- 計算量は、
Masked Multi-Head Attention
基本的な計算はMulti-Head Attentionと殆ど変わらないのですが、文全体を同時に計算するために積の一部をマスクするという仕組みを持っています。
transformerは先の単語を予測するタスクなのに文全体でモデルに一回しか情報が流れないという構造を持っています。(# 入力から辿る「transformerモデルが解く問題」で後述)
訓練時にの要素の一部を消すことで先の単語の情報が入らないようにしています
Multi-Head Attention
入力の (サイズは(n,d))はencoderから来て、それらとから来たでAttentionを計算しています。
- エンコーダー側の入力とデコーダー側の入力のサイズが違う場合はpaddingなどすればよいのであまり問題ではない
- 計算量はencoderと同じ
Feedworward
encoderと同じなので省略
計算量についてのまとめ
headを考えないととなります。multiheadになると、部分の計算が変わりますが、計算量の理論値は変わらないでしょう。
- nが数千ぐらいだと思うと、vocabsizeが意外に大きくて用途によってはここがボトルネックになるかも...とは思う。そもそも、埋め込み行列がメモリ的にも結構食うという話(Takase,Kobayashi 2020)もある。将来モデルの中身が小さくなっていったら、このあたりがボトルネックになるかも。
transformerは『一本の入力文(:「私は猫を飼っている」)と途中までの出力文「I have a」()から、「cat」を予測する』という問題(i=2)を、すべてのiについて同時に並列に解いています。数式に書き表すと以下の様になります。
モデルに対する入力は『x:「私は猫を飼っている」』ではなく、『:「私は猫を飼っている」+ 「I have a」』となります。
- なんとなく、「入力を入れると出力が出る」みたいな感覚があったのですがまずかったです。transformerも本質的には、「次の単語を予測する」だけの機能を持っていて、それを並列化している、と思うと腑に落ちました。
- この並列化はMasked Multi-Head Attentionで上手い事やって、文章全体を一回で扱えるdecoder機構を作った、と言えると思います。
推論時
入力時は上手く並列化して文全体を一回モデルに流すだけで計算出来ていますが、推論時はそうはいきません。decoderモデルを出力単語数分だけ回すことになり、がn回かかり、nに注目するとn3の計算量がかかります。しかしながら、実はここでもうまいことやってdecoderモデルの計算がO(n2)より軽くできます
- こちらの記事(Lei Mao Transformer Autoregressive Inference Optimization)に説明されていることを解説する。
- これがPytorchなどの標準実装になっているのかはよく分からない(todo)
- n2かかる部分というのはそもそも、Attention機構のの行列積と、これから出てきたものにVをかける部分なので、[tex: softmax(Q{n+1}K{n+1}^{T})V_{n+1}について考える。
- 今求めたいのは、この計算結果の中で一番下の行だけであることを考えると、実は計算しなくてよい部分が多く、O(n)に落ちる
おわりに
最初の疑問に今自分が答えるとすれば以下のようになる。
- 友人Aとの会話
- transformerは従来のモデルと比べてモデルの能力自体は失わずに文全体を同時に扱う機構を持っているため高速に計算できる
- 先輩Bとの会話
- 推論時は実は計算しなくてよい部分が多く、decoderモデル一回当たりの計算がO(n)に落ちていて、文全体でO(n2)になっている
あんまり自信はないんで間違いがあったら教えてください。
空間計算量
系列変換モデルとattention機構
todo
GPT
todo
BERT
todo