LSTM单元报错

运行报ValueError: Trying to share variable rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel, but specified shape (1024, 2048) and found shape (595, 2048).
再执行dynamic_rnn时报的
我的tensorflow版本1.3.0,python版本3.5.2

找到原因了,构建lstm的代码lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size),当你构建多层lstm时由于都是用的这一个lstm所以就报以上错误,正确的做法是:
for i in range(num_layers):
lstm = tf.contrib.rnn.BasicLSTMCell(lstm_size)
drop = tf.contrib.rnn.DropoutWrapper(lstm, output_keep_prob=keep_prob)
stack_drop.append(drop)
cell = tf.contrib.rnn.MultiRNNCell(stack_drop, state_is_tuple = True)

猜你喜欢

转载自blog.csdn.net/weixin_38527080/article/details/86144519