5-5-3 RNNはどう学習するのか

前の章では、RNNが、過去の情報を少し持ちながら、順番にデータを読むネットワークであることを見てきました。
文章、音声、株価、気温。
こうしたデータでは、今この瞬間だけを見るのではなく、その前に何があったかがとても大切です。
では、そのようなRNNは、どうやって学習しているのでしょうか。

ここで大事になるのが、

・BPTT
・教師強制

という考え方です。

BPTT(Backpropagation Through Time)

長い物語を読んでいる場面を想像してみてください。

あなたは一文ずつ読み進めていきます。
そのとき、前の文の内容を少し覚えながら、
次の文を理解していきます。

でも最後まで読んだところで、

「あれ、この結末になるなら、途中の読み取り方が少し違っていたかもしれない」

と気づくことがあります。

すると私たちは、

・どこで意味を取り違えたのか
・どの場面の理解が後の流れに影響したのか

を、後ろからたどって考えますよね。

RNNの学習も、これと少し似ています。
まずデータを順番に読んでいく。
そして最後に誤差を見て、その間違いが、どの時点から生まれたのかを過去へさかのぼってたどるのです。

この「時間をさかのぼって誤差を伝える」方法を、BPTT(Backpropagation Through Time)と呼びます。
名前だけだと少し難しそうだけれど、ここではこう考えるとわかりやすいです。

BPTTは、誤差逆伝播法のRNN版

以前見たふつうの誤差逆伝播法は、ネットワークの層を後ろ向きにたどって誤差を伝えていくものでした。
でもRNNでは、ネットワークは層だけでなく、時間方向にもつながっています

たとえば文章なら、

・1語目を読んだときの状態
・2語目を読んだときの状態
・3語目を読んだときの状態

が、時間の流れに沿って続いています。

だからRNNでは、

・層をさかのぼる
・さらに時間もさかのぼる

必要があります。
つまりBPTTとは、誤差逆伝播法を、時間の流れをもつRNN向けにした学習方法なのです。

RNNは見た目には一つのネットワークですが、学習のときにはそれを

1時刻目
2時刻目
3時刻目
4時刻目…

というふうに、時間方向にずらっと並べて考えることがあります。
文章なら、一語ずつ並べて考える感じです。
そうするとRNNは、「時間に沿って長く伸びたネットワーク」のように見えます。
BPTTでは、この時間方向に並んだネットワークを最後の時刻から最初のほうへ向かって誤差を伝えていくのです。
つまり、

・今の間違いが、少し前の状態にどう関係していたか
・さらにその前にどうつながっていたか

を、順番にたどっていくわけです。

RNNでは、今の出力が今の入力だけで決まるわけではありません。
たとえば、「今日は とても 寒い」という文なら、最後の「寒い」の理解には「今日は」「とても」という前の流れが関わっています。
もし最後の予測が間違っていたなら、その原因は「寒い」のところだけではなく、もっと前の単語の受け取り方にあるかもしれません。
だからRNNは、最後の間違いを見て終わるのではなく、過去までさかのぼって学び直す必要があるのです。
これは、物語の結末がしっくりこなかったときに、最後の一文だけでなく前の章まで戻って読み直すのに似ています。

教師強制

ここで、RNNの学習でもう一つ大事なのが、教師強制(teacher forcing)です。
この言葉の「教師」は、その場で人間が横について教える、という意味ではありません。
ここでいう教師とは、学習データの中に入っている正解のことです。
つまり教師強制とは、モデルの前の予測ではなく、正解データを次の入力として使いながら学習する方法です。

これを、子どもの音読練習で考えてみましょう。

子どもが文を読んでいて、「今日は とても あおい…」と読み間違えてしまったとします。
このまま、その間違った読みをもとに次の語まで読ませると、どんどん文全体がおかしくなっていくかもしれません。
でも先生が横で、「ここは『寒い』だよ。じゃあ次を読んでみよう」と正しい語を示してくれたら、文の流れを崩さずに練習できますよね。
教師強制は、これとよく似ています。

RNNが前の時刻で少し間違った予測をしても、次の時刻にはその予測を使うのではなく、本来の正解を入力として与えるのです。

RNNの学習の初めのころは、予測がまだ不安定です。
もしその不安定な予測をそのまま次の入力に使ってしまうと、小さな間違いが次の間違いを呼び、さらにその次も崩れていく、ということが起こります。
つまり、
・最初の小さなミス
・そのミスをもとにした次のミス
・さらに広がるミス
というふうに、間違いが連鎖しやすくなるのです。
教師強制では、その連鎖をいったん止めて、正しい流れを見せながら学習させることができます。
だからRNNは、学習の初期でも比較的安定して学びやすくなります。

ただし教師強制には、注意点もあります。
学習中はいつも正解を見せてもらっていたのに、実際に文章生成などで使うときには、モデルは自分の予測だけを頼りに次を出していかなければなりません。

つまり、

学習中→ 正解に導いてもらえる
実際に使うとき→ 自分で流れをつくる

という違いがあります。
この差が、RNNの学習の難しさの一つでもあります。

まとめ

ここまでを見ると、RNNの学習は、少し複雑に見えるかもしれません。
でも本質はシンプルです。
RNNは、今この瞬間だけを直すのではなく、流れ全体の中でどこがずれていたかを学ぶネットワークなのです。

そのために、

・BPTTで時間をさかのぼって誤差を伝え
・教師強制で正しい流れを見せながら練習する

という工夫が使われます。

BPTT
→ 誤差逆伝播法のRNN版
→ 誤差を層だけでなく時間方向にもさかのぼって伝える方法

教師強制
→ モデルの前の予測ではなく、学習データに含まれる正解を次の入力として使いながら学習する方法

next ▶ RNNの課題と発展