参考:
qiita.com
続きです
ミニバッチ作成する箇所のコードなんですけど、
# minibatch作成 minibatch = data[num*BATCH_SIZE: (num+1)*BATCH_SIZE] # 読み込み用のデータ作成 enc_words, dec_words = make_minibatch(minibatch)
dataを適当に分割抽出して、処理を施しenc_words, dec_wordsに突っ込んでますね。この2つはただのリストだと思われる。ndarray
こちらがmake_minibatch
def make_minibatch(minibatch): # enc_wordsの作成 enc_words = [row[0] for row in minibatch] enc_max = np.max([len(row) for row in enc_words]) enc_words = np.array([[-1]*(enc_max -len(row)) + row for row in enc_words], dtype='int32') enc_words = enc_words.T # dec_wordsの作成 dec_words = [row[1] for row in minibatch] dec_max = np.max([len(row) for row in dec_words]) dec_words = np.array([row + [-1]*(dec_max - len(row)) for row in dec_words], dtype='int32') dec_words = dec_words.T return enc_words, dec_words
ちょっとよくわからないのでenc_words, dec_wordsの一番最後の文から見ていきます。
enc_words = enc_words.T dec_words = dec_words.T
Tってなんぞ?Variableのやつなのかnumpyなのか…
→numpyのっぽいですね。普通に転置。
numpy.ndarray.T — NumPy v1.14 Manual
2番目と3番目の文、
enc_max = np.max([len(row) for row in enc_words]) enc_words = np.array([[-1]*(enc_max -len(row)) + row for row in enc_words], dtype='int32') dec_max = np.max([len(row) for row in dec_words]) dec_words = np.array([row + [-1]*(dec_max - len(row)) for row in dec_words], dtype='int32')
enc_words, dec_wordsの各行(各文)をnumpyで変換しているところ。各文のサイズが全て一緒になるように-1でパディングしているようで。ということはこの時点で単語列に分解されているものを使っているのね。
1番目の文、
enc_words = [row[0] for row in minibatch] dec_words = [row[1] for row in minibatch]
minibatchもといdataのrow[0]がenc_words、row[1]がdec_wordsなので、
[[[row0[0]],[row0[1]]], [[row1[0]],[row1[1]]], [[row2[0]],[row2[1]]], ... ]
みたいな感じにデータを整形すればよいということがわかりました(まあこの辺は処理ができるなら好きなようにやればいいとは思いますがね)。
いずれにせよめんどっちいかも…
次:第6回 コーパス整形