リムナンテスは愉快な気分

徒然なるままに、言語、数学、音楽、プログラミング、時々人生についての記事を書きます

バッチ処理

Chainerで機械翻訳 (python) 第5回

参考:
qiita.com

 

続きです

ミニバッチ作成する箇所のコードなんですけど、

# minibatch作成
minibatch = data[num*BATCH_SIZE: (num+1)*BATCH_SIZE]
# 読み込み用のデータ作成
enc_words, dec_words = make_minibatch(minibatch)

dataを適当に分割抽出して、処理を施しenc_words, dec_wordsに突っ込んでますね。この2つはただのリストだと思われる。ndarray

 

こちらがmake_minibatch

def make_minibatch(minibatch):
  # enc_wordsの作成
  enc_words = [row[0] for row in minibatch]
  enc_max = np.max([len(row) for row in enc_words])
  enc_words = np.array([[-1]*(enc_max -len(row)) + row for row in enc_words], dtype='int32')
  enc_words = enc_words.T

  # dec_wordsの作成
  dec_words = [row[1] for row in minibatch]
  dec_max = np.max([len(row) for row in dec_words])
  dec_words = np.array([row + [-1]*(dec_max - len(row)) for row in dec_words], dtype='int32')
  dec_words = dec_words.T

  return enc_words, dec_words

ちょっとよくわからないのでenc_words, dec_wordsの一番最後の文から見ていきます。

enc_words = enc_words.T

dec_words = dec_words.T

Tってなんぞ?Variableのやつなのかnumpyなのか…

→numpyのっぽいですね。普通に転置。

numpy.ndarray.T — NumPy v1.14 Manual
 

2番目と3番目の文、

  enc_max = np.max([len(row) for row in enc_words])
  enc_words = np.array([[-1]*(enc_max -len(row)) + row for row in enc_words], dtype='int32')

  dec_max = np.max([len(row) for row in dec_words])
  dec_words = np.array([row + [-1]*(dec_max - len(row)) for row in dec_words], dtype='int32')

enc_words, dec_wordsの各行(各文)をnumpyで変換しているところ。各文のサイズが全て一緒になるように-1でパディングしているようで。ということはこの時点で単語列に分解されているものを使っているのね。


1番目の文、

enc_words = [row[0] for row in minibatch]

dec_words = [row[1] for row in minibatch]

minibatchもといdataのrow[0]がenc_words、row[1]がdec_wordsなので、
[[[row0[0]],[row0[1]]], [[row1[0]],[row1[1]]], [[row2[0]],[row2[1]]], ... ]
みたいな感じにデータを整形すればよいということがわかりました(まあこの辺は処理ができるなら好きなようにやればいいとは思いますがね)。
いずれにせよめんどっちいかも…

次:第6回 コーパス整形