paddle.summary 的使用问题

今天在使用paddle.summary打印模型的时候出现了下面的错误:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-155-ec604926d5fd> in <module>
      5 recall_model=DNNRecallLayer(sparse_feature_number=600000, sparse_feature_dim=9, fc_sizes=fc_sizes)
      6 
----> 7 param_info = paddle.summary(recall_model,input_size=[(1,),(4,),(3,),(1,)])
      8 print(param_info)
      9 
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model_summary.py in summary(net, input_size, dtypes)
    147 
    148     _input_size = _check_input(_input_size)
--> 149     result, params_info = summary_string(net, _input_size, dtypes)
    150     print(result)
    151 
</opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/decorator.py:decorator-gen-342> in summary_string(model, input_size, dtypes)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/base.py in _decorate_function(func, *args, **kwargs)
    313         def _decorate_function(func, *args, **kwargs):
    314             with self:
--> 315                 return func(*args, **kwargs)
    316 
    317         @decorator.decorator
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/hapi/model_summary.py in summary_string(model, input_size, dtypes)
    274 
    275     # make a forward pass
--> 276     model(*x)
    277 
    278     # remove these hooks
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py in __call__(self, *inputs, **kwargs)
    889                 self._built = True
    890 
--> 891             outputs = self.forward(*inputs, **kwargs)
    892 
    893             for forward_post_hook in self._forward_post_hooks.values():
<ipython-input-128-9651e9a4aa73> in forward(self, batch_size, user_sparse_inputs, mov_sparse_inputs, label_input)
     60         user_sparse_embed_seq = []
     61         for s_input in user_sparse_inputs:
---> 62             emb = self.embedding(s_input)
     63             emb = paddle.reshape(emb, shape=[-1, self.sparse_feature_dim])
     64             user_sparse_embed_seq.append(emb)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py in __call__(self, *inputs, **kwargs)
    889                 self._built = True
    890 
--> 891             outputs = self.forward(*inputs, **kwargs)
    892 
    893             for forward_post_hook in self._forward_post_hooks.values():
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/layer/common.py in forward(self, x)
   1288             padding_idx=self._padding_idx,
   1289             sparse=self._sparse,
-> 1290             name=self._name)
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/nn/functional/input.py in embedding(x, weight, padding_idx, sparse, name)
    200         return core.ops.lookup_table_v2(
    201             weight, x, 'is_sparse', sparse, 'is_distributed', False,
--> 202             'remote_prefetch', False, 'padding_idx', padding_idx)
    203     else:
    204         helper = LayerHelper('embedding', **locals())
ValueError: (InvalidArgument) Tensor holds the wrong type, it holds float, but desires to be int64_t.
  [Hint: Expected valid == true, but received valid:0 != true:1.] (at /paddle/paddle/fluid/framework/tensor_impl.h:33)
  [operator < lookup_table_v2 > error]

我的代码为:

# Number of training epochs
epochs=3
# Define the model
fc_sizes=[512, 256, 128, 32]
recall_model=DNNRecallLayer(sparse_feature_number=600000, sparse_feature_dim=9, fc_sizes=fc_sizes)
# NOTE: this call reproduces the error shown above. No `dtypes` argument is
# passed, so the dummy inputs that summary builds are float tensors
# ("it holds float"), while the embedding lookup in DNNRecallLayer.forward
# (lookup_table_v2) requires int64 indices ("desires to be int64_t").
# The traceback above was captured with input_size=[(1,), ...]; only the first
# shape differs from this line, which presumably does not affect the dtype error.
param_info = paddle.summary(recall_model,input_size=[(2,),(4,),(3,),(1,)])
print(param_info)

解决方法

# Number of training epochs
epochs = 3
# Define the model
fc_sizes = [512, 256, 128, 32]
recall_model = DNNRecallLayer(sparse_feature_number=600000, sparse_feature_dim=9, fc_sizes=fc_sizes)
# paddle.summary builds random dummy inputs that are float32 unless `dtypes`
# is given, but the embedding lookup (lookup_table_v2) inside DNNRecallLayer
# needs int64 indices. Pass the dtype explicitly as the string 'int64' rather
# than the Python builtin `int`: `dtypes` is documented as a dtype string, and
# `int` is resolved through NumPy to a platform-dependent width (e.g. int32 on
# Windows with NumPy < 2), which may not satisfy the int64 requirement.
# The single dtype string is used for all four inputs.
param_info = paddle.summary(recall_model, input_size=[(2,), (4,), (3,), (1,)], dtypes='int64')
print(param_info)

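这样就行了,关键是要设置 dtypes 参数:paddle.summary 默认用 float32 构造随机输入,而模型中的 embedding 层(lookup_table_v2)要求 int64 类型的索引,所以会报 "it holds float, but desires to be int64_t"。通过 dtypes 把输入指定为整型即可,推荐显式写成 dtypes='int64'(Python 内置的 int 在部分平台上可能被转换成 int32)。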

猜你喜欢

转载自blog.csdn.net/w5688414/article/details/113770850