torch.nn.utils.rnn.pad_sequence()详解【Pytorch入门手册】

函数原型

torch.nn.utils.rnn.pad_sequence(sequences, batch_first=False, padding_value=0.0)

函数功能

此函数返回大小为 T x B x *B x T x * 的张量,其中 T 是最长序列的长度。

参数详解

  • sequences (list[Tensor]): 可变长度序列的列表,shape=[batch_size, N],N长度不一。
  • batch_first (bool, optional) :默认batch_size在第一维度
  • padding_value (float, optional) :填充的值,默认为0。

示例

from torch.nn.utils.rnn import pad_sequence

a = torch.ones(25, 300)
b = torch.ones(22, 300)
c = torch.ones(15, 300)
pad_sequence([a, b, c]).size()

输出

torch.Size([25, 3, 300])

猜你喜欢

转载自blog.csdn.net/qq_38251616/article/details/125222012