Copyright notice: this is an original article by the blogger and may not be reproduced without the blogger's permission. https://blog.csdn.net/guotong1988/article/details/81941879
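While reading github.com/tensorflow/models/tree/master/tutorials/rnn/ptb, I saw in ptb_word_lm.py that the model's final input is held in class PTBInput. Tracing that back to the ptb_producer method in reader.py, the end of the method makes it clear that the RNN's input x and target y are the same text, offset by one position: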
x = tf.strided_slice(data, [0, i * num_steps],
                     [batch_size, (i + 1) * num_steps])
y = tf.strided_slice(data, [0, i * num_steps + 1],
                     [batch_size, (i + 1) * num_steps + 1])
return x, y
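To make the one-position offset concrete, here is a minimal NumPy sketch of the same slicing. It is not the tutorial's code: the function name shifted_batches is hypothetical, and the reshape into a [batch_size, batch_len] matrix is my assumption about how data is laid out before the slices are taken.

```python
import numpy as np

def shifted_batches(raw_data, batch_size, num_steps):
    """Yield (x, y) pairs where y is x shifted one token to the right.

    Hypothetical helper: mirrors the two strided slices with plain NumPy indexing.
    """
    raw_data = np.asarray(raw_data)
    batch_len = len(raw_data) // batch_size
    # Assumption: the token stream is cut into batch_size contiguous rows.
    data = raw_data[:batch_size * batch_len].reshape(batch_size, batch_len)
    # One column is held back so that y can always look one step ahead.
    epoch_size = (batch_len - 1) // num_steps
    for i in range(epoch_size):
        x = data[:, i * num_steps:(i + 1) * num_steps]
        y = data[:, i * num_steps + 1:(i + 1) * num_steps + 1]
        yield x, y

# Toy token ids 0..19 stand in for word ids.
batches = list(shifted_batches(range(20), batch_size=2, num_steps=3))
first_x, first_y = batches[0]
```

With these toy ids, every entry of first_y is the id that follows the matching entry of first_x in the stream, which is exactly the target a next-word predictor needs.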
The usual way of feeding input to an RNN:
 t=0  t=1    t=2  t=3     t=4
[The, brown, fox, is,     quick]
[The, red,   fox, jumped, high]
words_in_dataset[0] = [The, The]
words_in_dataset[1] = [brown, red]
words_in_dataset[2] = [fox, fox]
words_in_dataset[3] = [is, jumped]
words_in_dataset[4] = [quick, high]
batch_size = 2, time_steps = 5
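The time-major layout above is just a transpose of the two sentences. A minimal sketch in plain Python; the variable name sentences is my own and does not come from the tutorial:

```python
# Two sentences of equal length, one per batch entry.
sentences = [
    ["The", "brown", "fox", "is", "quick"],
    ["The", "red", "fox", "jumped", "high"],
]

# zip(*sentences) groups the words that share a time step.
words_in_dataset = [list(step) for step in zip(*sentences)]

batch_size = len(sentences)         # number of sentences processed in parallel
time_steps = len(words_in_dataset)  # number of positions per sentence
```

Each words_in_dataset[t] then holds the t-th word of every sentence in the batch.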
Converted into x and y, this should be:
x=[_start_, The, brown, fox, is, quick]
y=[The, brown, fox, is, quick, _end_]
x=[_start_, The, red, fox, jumped, high]
y=[The, red, fox, jumped, high, _end_]
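The same construction can be sketched in a few lines of Python. The helper name make_xy is hypothetical, and the _start_ and _end_ markers follow the example above; the PTB reader itself works on one long token stream rather than on padded sentences.

```python
START, END = "_start_", "_end_"

def make_xy(sentence):
    """Return (x, y) where y[t] is the word that follows x[t]."""
    x = [START] + sentence  # input: sentence with a start marker in front
    y = sentence + [END]    # target: sentence with an end marker appended
    return x, y

x, y = make_xy(["The", "brown", "fox", "is", "quick"])
```

At every position t the target y[t] equals the next input x[t + 1], so the pair differs only by the one-position shift.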
Wikipedia's explanation of a language model: "A statistical language model is a probability distribution over sequences of words, given some linguistic context."
Here that corresponds to feeding in an x sequence and predicting what the y sequence looks like.
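One common way to write this down, assuming the usual left-to-right factorization (the notation w_t for the t-th word is mine, not the tutorial's):

```latex
P(w_1, \dots, w_T) = \prod_{t=1}^{T} P(w_t \mid w_1, \dots, w_{t-1})
```

Under this reading, at step t the RNN has consumed x_1, ..., x_t (the start marker plus the first t-1 words) and is trained to assign high probability to y_t = w_t, which is why y is simply x shifted by one position.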