リムナンテスは愉快な気分

徒然なるままに、言語、数学、音楽、プログラミング、時々人生についての記事を書きます

学習器

Chainerで機械翻訳 (python) 第4回

参考:
qiita.com

学習部分のコードをがりがり書いていきますよ。
vocab_sizeは参考元だと「辞書」とやらを使っているがよくわからないライブラリの何かだと思われるんで、
今回は定数扱い(あとでnltkで書き直すかも)。

まずはSeq2Seqのインスタンスを作成、モデルの初期設定。

    # モデル(Seq2Seq)のインスタンス化
    model = Seq2Seq(vocab_size=VOCAB_SIZE,
                    embed_size=EMBED_SIZE,
                    hidden_size=HIDDEN_SIZE,
                    batch_size=BATCH_SIZE)
    # モデルの初期化
    model.reset()
    # GPU or CPU(今回はCPUのみ実装)
    ARR = np

    # 学習開始
    for epoch in range(EPOCH_NUM):
        # epochごとにoptimizerの初期化
        opt = optimizers.Adam()
        # modelをoptimizerにセット
        opt.setup(model)
        # 勾配調整
        opt.add_hook(optimizer.GradientClipping(5))

        # 学習データ読み込み
        data = # うまいこと読み込む
        # データのシャッフル
        random.shuffle(data)

optimizerというのは、(よくわからないんですが)パラメータ、勾配を勝手にいい感じに求めてくれるようなやつらしいです。とりあえず突っ込んでおきましょう。
Introduction to Chainer — Chainer 3.4.0 documentation

そして学習データをimportしたかったのですが、参考元のdataって具体的にどういう形式なんですかね、よくわからん。書いてほしかった。
とりあえず先に進んで、何をしているかだけ追うことにします。minibatchあたりで実際にencoder、decoderに入れて処理しようとしてたりするので、そこで判断したいと思います。

バッチ学習部をガリガリ書きますよ~

        # バッチ学習スタート
        for num in range(len(data)//BATCH_SIZE):
            # minibatch作成
            minibatch = data[num*BATCH_SIZE: (num+1)*BATCH_SIZE]
            # 読み込み用のデータ作成
            enc_words, dec_words = make_minibatch(minibatch)
            # 順伝搬で損失計算
            total_loss = forward(enc_words=enc_words,
                                 dec_words=dec_words,
                                 model=model,
                                 ARR=ARR)
            # 誤差逆伝搬で勾配計算
            total_loss.backward()
            # 計算した勾配を使ってネットワークを更新
            opt.update()
            # 記録された勾配を初期化
            opt.zero_grads()
        # epochごとにモデル保存
        serializers.save_hdf5(outputpath, model)

//は切り捨て除算です。(雑魚なので知らなかった…)
細かい動作原理は後回しにして、まずは概形から。
学習データを一定の単位で分けて、各単位ごとに学習をしていきます(minibatch)
で、順伝搬(損失計算)→逆伝搬(勾配計算)→ネットワーク更新の順。

forward()は第3回のときに作った関数です。
現状の機械(model=Seq2Seq)にデータ(enc_words)を流し込んだ結果と正解データ(dec_words)を比較して正誤判定する関数でしたね。確か。
で、その正誤判定をもとに勾配計算、重みの更新をすると。

順伝搬、逆伝搬とパラメータの更新については、このあたりの記事がとても詳しく書かれていて分かり易い。
Mind で Neural Network (準備編2) 順伝播・逆伝播 図解 - Qiita


データの流し方、ミニバッチらへんは長くなりそうなので次回。

次:第5回 バッチ処理