まとめ
-
単語数をVとして、V^3からV^2ぐらいへ高速化した。
-
バグを見つけるなどして、定数倍の高速化にも努めた。
-
で公開している。
m2scorerとは
Grammalyのような、文法的に誤った文を文法的に正しい文に直すタスクがあり、文法誤り訂正(GEC)と呼ばれています。
M2(MaxMatch)とは文法誤り訂正の評価手法の一つです。提案されたのは少し古いのですが、CoNLL-2014と呼ばれるコンペの評価指標として採用されたこともあって、過去の実験と比較する時などには必ず用いられています。
M2はm2scorerというレポジトリでgithubに公開されています。
https://github.com/nusnlp/m2scorer
事の発端
国際学会への投稿を目指していたのですが、それにあたって大規模な実験を行う必要がありました。しかし、その実験の管理がなかなか難しく実験や評価を自動化したいなと思っていました。
しかし、自動化にあたって評価指標であるM2の計算が非常に遅いことがあり、困っていました。(特に、モデルの出力が長くなってしまった時は止まらなくなります。)
なぜ遅い?(アルゴリズム)
すでにgithubので遅いという話は問題になっていて、shotakoyamaさんらがなぜ遅いのかを議論していました。単語数をVとすると、計算時間がO(V^3)程度になる場所が2か所あるようでした。
指摘点1:ベルマンフォード
MaxMatchは下図のような文への編集を表すグラフの最短経路問題として解かれています。頂点数は、単語数に大体依存していて、最大でも単語数に2倍程度です。
最短経路問題を解く際、負の辺があるのでベルマンフォードと呼ばれるアルゴリズムが用いられており、これにはO(E*V)程度の時間がかかります。
以下、E ≒V^2 とします。(実際には、もう少し辺の数は少ないです)
Vである頂点数は単語の数程度になるのですが、これが100程度になってくると、10の6乗を超えてきて計算時間が怪しくなってきます。
図の引用元
(Dahlmeier & Ng, NAACL 2012)
指摘点2: transitive_arcs
上のようなグラフを構築する際、特殊な辺を追加するのですがこれの追加にO(V^3)かかっていました。
どう遅い?(実験)
単語数が93の入力についてそれぞれの指摘点別に時間を計測すると、以下のような計算時間になっていました。
指摘点1のベルマンフォード: 37.3秒
指摘点2(関数名transitive_arcs):48.5秒
単語数がこれの2,3倍になると計算時間は8倍、27倍程度になると推測できるので厳しいですね。
評価全体にかかる時間は1分28秒で、ベルマンフォードが2回走ることを考えると、37.3 秒*2 + 48.5 秒 = 1分23秒ということで、ここ以外にボトルネックはなさそうでした。
高速化(アルゴリズム)
指摘点1
githubのissueでも指摘されていたことですが、このグラフはDAGなので、トポロジカルソートをして先頭から順に見ていけばO(V + E)で計算を終えることができます。
6/19 追記。 トポロジカルソートで書き換えました。
が、トポロジカルソートの実装が面倒であったのでこれをサボりました。負の辺があっても今回のケースでは負の閉路が無いのでdikstraで擬似多項式時間まで落ちます。()
6/19 追記。 iet461さんのコメントの通り、以下は間違っているようです。
また、ちゃんとした証明を回していないのですが以下の用な考察があります。
今回はDAGなので、頂点が2回以上pushされる延べ回数は負の辺の数で抑えられるかと思います(多分)。負の辺の数は単語数個程度しかなさそうなので(多分)、擬似多項式時間ではなく多項式時間まで抑えられると思います。
指摘点2
辺を連結した辺を考えるのですが、この計算にO(V^3)かかっていました。以下のようなものです。
try-exceptを用いてO(V^3)で実装しているのですが冗長です。辺がある部分にだけ操作を行うが、辺がない部分には操作しないような実装に書き換えました。O(V^2)ぐらいに落ちたと思います。
O(V^3)かかる遅いコード
for k in range(len(V)):
vk = V[k]
if very_verbose:
print("v _k :", vk)
for i in range(len(V)):
vi = V[i]
if very_verbose:
print("v _i :", vi)
try:
eik = edits[(vi, vk)]
except KeyError:
continue
for j in range(len(V)):
vj = V[j]
if very_verbose:
print("v _j :", vj)
try:
ekj = edits[(vk, vj)]
except KeyError:
continue
バグっぽい遅いところ
その他の点ですが、コードが意図した通り書かれていないと思われる部分を確認しました。
これはEがリストであり、Eの要素を一つずつ見ながら、不要であった場合E.remove()でEから削除するのですが、pythonだと(2系であっても)削除した次の要素が飛ばされて実行されてしまいます。
for edge in E:
e = edits[edge]
if e[0] == 'noop' and dist[edge] > 1:
if very_verbose:
print(" remove noop arc v_i -> vj:", edge)
E.remove(edge)
dist[edge] = float('inf')
del edits[edge]
これを見つけた時、評価の値に影響を与えているんじゃないかと思ったのですが、「この部分は冗長な辺を削除する」という操作を行っており、評価の値には関係ないようです。
なお、このバグを直すことで定数倍は早くなっていると思います。
高速化(実験)
先の文章で測ると以下の様になりました。
指摘点1のベルマンフォード: 37.3秒 ー> 0.17 秒
指摘点2(関数名transitive_arcs):48.5秒 ー> 0.63 秒
まとめと今後
さらなる高速化について
指摘1の部分をトポロジカルソートに書き換えればオーダーでlogVほど速くなると思うのですが、まぁいいかなという気持ちです。
テストについて
人のバグを見つけたわけですが、自分のコードにもバグがある可能性があります。手元にあるCoNLLの結果などで少しテストをしているのですが、十分ではないでしょう。
網羅的なテストの方法を考えたいなと思っています。
ご指摘されていたさん、相談に乗ってくれた五藤さんに感謝します。
で公開しているので是非使ってみてくださいね~