【pytorch】nn.utils.rnn.pad_sequence的使用


错误

The size of tensor a (3) must match the size of tensor b (5) at non-singleton dimension 1

在使用nn.utils.rnn.pad_sequence时,遇到如上错误,原因是使用方式错误.

使用说明

用padding_value填充可变长度张量列表
pad_sequence 沿新维度堆叠张量列表,
并将它们垫成相等的长度。
例如,如果输入是列表
大小为“L x *”的序列,如果batch_first为False,并且“T x B x *”

“B”是批量大小。它等于“序列”中元素的数量。
“T”是最长序列的长度。
“L”是序列的长度。
“*”是任意数量的尾随维度,包括没有。

例子:
    >>> from torch.nn.utils.rnn import pad_sequence
    >>> a = torch.ones(25, 300)
    >>> b = torch.ones(22, 300)
    >>> c = torch.ones(15, 300)
    >>> pad_sequence([a, b, c]).size()
    torch.Size([25, 3, 300])

注意:
    该函数返回大小为“T x B x *”或“B x T x *”的张量
    其中“T”是最长序列的长度。该函数假设
    序列中所有张量的尾随维度和类型都是相同的。

参数:
    序列 (list[Tensor]):可变长度序列的列表。
    batch_first(bool,可选):如果为 True,输出将在“B x T x *”中,否则在
        ``T x B x *`` 否则。默认值:假。
    padding_value (float,可选):填充元素的值。默认值:0。

返回:
    如果:attr:`batch_first` 为``False``,则大小为``T x B x *`` 的张量。
    大小为“B x T x *”的张量,否则反过来

样例代码

最后一维必须一致,可以理解为embeding层

from torch import nn
import torch

a = torch.randn(3,5)
b = torch.randn(2,5)

out = nn.utils.rnn.pad_sequence([a,b])
print(out)

当维度大于2时, 一般会包含batch size,所以要指定batch_size是否是第一维度

from torch import nn
import torch

a = torch.randn(4,3,5)
b = torch.randn(2,3,5)

out = nn.utils.rnn.pad_sequence([a,b], batch_first=False)
print(out)

猜你喜欢

转载自blog.csdn.net/mimiduck/article/details/131358718