pytorch 之pad_sequence, pack_padded_sequence, pack_sequence, pad_packed_sequence使用

pad_sequence

该函数用padding_value来填充一个可变长度的张量列表。将长度较短的序列填充为和最长序列相同的长度。
一句话就是:填充句子到相同长度。
参数说明:

  • sequences(list[Tensor]):变长序列的列表。
  • batch_frist(bool,optional):如果为True,output形状为B × T × ∗ ,否则为T × B × ∗ ,默认情况为False。其中B BB为批次大小,T TT为填充后每个序列的长度。
  • padding_value(float,optional):填充元素的值。默认值:0。

输出:

如果 batch_first 是 False,张量的形状为T × B × ∗ 。否则,张量的形状为B × T × ∗ 。
举个栗子:

from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pack_sequence, pad_packed_sequence
content = [[12, 12, 11, 1, 21, 7, 7], [12, 12, 11, 1, 21], [12, 12, 11, 21]]
DATA = list(map(lambda x: torch.tensor(x), content))
p1 = pad_sequence(DATA, batch_first=True)
print(p1)

在这里插入图片描述

pack_padded_sequence

压紧(pack)一个包含可变长度的填充序列的张量,在使用pad_sequence函数进行填充的时候,产生了冗余,因此需要对其进行pack。
参数说明:

  • input(Tensor):一批量填充后的可变长度的序列。
  • lenghts(Tensor or list(int)):每个批次元素的序列长度列表。如果输入为张量形式则必须在CPU上,不能在GPU上。
  • batch_first(bool,optional):如果为True,则输入的形状为B × T × ∗,我一般将其设置为True
  • enforce_sorted(bool,optional):如果为True,则参数lenghts为按长度递减排序的序列,这样的话输入的input也需要进行排序。我一般将其设置为False。如果为False输入将被无条件地排序。
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pack_sequence, pad_packed_sequence

content = [[12, 12, 11, 1, 21, 7, 7], [12, 12, 11, 1, 21], [12, 12, 11, 21]]
DATA = list(map(lambda x: torch.tensor(x), content))
print(content, DATA)
p1 = pad_sequence(DATA, batch_first=True)
print(p1)
p2 = pack_padded_sequence(p1, [7, 5, 4], batch_first=True, enforce_sorted=False)
print(p2)

在这里插入图片描述

函数对返回的结果进行填充以恢复为原来的形状。
参数说明:

  • sequence(PackedSequence):需要填充的数据。
  • batch_first(bool,optional):如果为True,输出形状为B × T × ∗ B \times T \times *B×T×∗。
  • padding_value(float,optional):填充元素的值。
  • total_lenght(int,optional):如果不是无,输出将被填充成total_lenght。

输出:

包含填充序列的张量的元组,以及包含批次中每个序列的长度列表的张量。

from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence, pack_sequence, pad_packed_sequence

content = [[12, 12, 11, 1, 21, 7, 7], [12, 12, 11, 1, 21], [12, 12, 11, 21]]
DATA = list(map(lambda x: torch.tensor(x), content))
print(content, DATA)
p1 = pad_sequence(DATA, batch_first=True)
print(p1)
p2 = pack_padded_sequence(p1, [7, 5, 4], batch_first=True, enforce_sorted=False)
print(p2)
p3 = pad_packed_sequence(p2, batch_first=True)
print(p3)

在这里插入图片描述

pack_sequence

sequences (list[Tensor]): A list of sequences of decreasing length.enforce_sorted (bool, optional): if True, checks that the input contains sequences sorted by length in a decreasing order. If False, this condition is not checked. Default: True.

from torch.nn.utils.rnn import pack_sequence
import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5])
c = torch.tensor([6])
print(pack_sequence([a, b, c], True))

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_43788986/article/details/127618363
今日推荐