在创建预测数据时报错data and recursive_seq_lens do not match

  • 关键字:张量数据维度

  • 问题描述:在使用fluid.create_lod_tensor创建一个词向量预测输出,在执行创建的时候报错,数据维度和参数recursive_seq_lens不匹配。

  • 报错信息:

<ipython-input-7-422a1a374a70> in infer(use_cuda, inference_program, params_dirname)
     17     lod = [[2]]
     18 
---> 19     first_word = fluid.create_lod_tensor(data1, lod, place)
     20     second_word = fluid.create_lod_tensor(data2, lod, place)
     21     third_word = fluid.create_lod_tensor(data3, lod, place)

/usr/local/lib/python3.5/dist-packages/paddle/fluid/lod_tensor.py in create_lod_tensor(data, recursive_seq_lens, place)
     74         assert [
     75             new_recursive_seq_lens
---> 76         ] == recursive_seq_lens, "data and recursive_seq_lens do not match"
     77         flattened_data = np.concatenate(data, axis=0).astype("int64")
     78         flattened_data = flattened_data.reshape([len(flattened_data), 1])

AssertionError: data and recursive_seq_lens do not match
  • 问题复现:定义一个形状为(1, 1)的整型数据,然后再定义一个(1, 1)的表示列表的长度,这个设置为2,最后在执行fluid.create_lod_tensor接口创建预测数据的时候报错。错误代码如下:
data1 = [[211]]  # 'among'
data2 = [[6]]  # 'a'
data3 = [[96]]  # 'group'
data4 = [[4]]  # 'of'
lod = [[2]]

first_word = fluid.create_lod_tensor(data1, lod, place)
second_word = fluid.create_lod_tensor(data2, lod, place)
third_word = fluid.create_lod_tensor(data3, lod, place)
fourth_word = fluid.create_lod_tensor(data4, lod, place)
  • 解决问题:recursive_seq_lens这个参数是指输入的数据列表的长度信息,我们输入的数据的长度是1,所以参数recursive_seq_lens也应该是的值也应该是[[1]],而不是输入数据的维度数量。正确代码如下:
data1 = [[211]]  # 'among'
data2 = [[6]]  # 'a'
data3 = [[96]]  # 'group'
data4 = [[4]]  # 'of'
lod = [[1]]

first_word = fluid.create_lod_tensor(data1, lod, place)
second_word = fluid.create_lod_tensor(data2, lod, place)
third_word = fluid.create_lod_tensor(data3, lod, place)
fourth_word = fluid.create_lod_tensor(data4, lod, place)
  • 问题拓展:fluid.create_lod_tensor接口也支持多维不同长度的数据,如:创建一个张量来表示两个句子,一个是2个单词,一个是3个单词。那就需要设置recursive_seq_lens参数的值为[[2,3]]

猜你喜欢

转载自blog.csdn.net/PaddlePaddle/article/details/87929265