PyTorch's nn.utils.rnn.pack_padded_sequence and nn.utils.rnn.pad_packed_sequence

1. Official documentation:

torch.nn — PyTorch 1.11.0 documentation

 

2. Application background:

When processing data with PyTorch, sample sequences are usually fed to the model in batches, and the sequences within a batch generally have unequal lengths, which an RNN cannot handle directly. The usual practice is therefore to pad every sequence in a batch to the length of the longest one, as sketched below.
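Padding can be done by hand or with nn.utils.rnn.pad_sequence; a minimal sketch (the example values are ours):

import torch
from torch.nn.utils.rnn import pad_sequence

# Three sequences of different lengths (e.g. token ids).
seqs = [torch.tensor([1, 2, 3, 4]),
        torch.tensor([1, 2]),
        torch.tensor([1, 2, 3])]

# Pad to the longest sequence in the batch (length 4 here) with zeros.
batch = pad_sequence(seqs, batch_first=True, padding_value=0)
print(batch)
# tensor([[1, 2, 3, 4],
#         [1, 2, 0, 0],
#         [1, 2, 3, 0]])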

But padding introduces a problem: for a heavily padded sequence, the RNN keeps running over many useless padding characters. We would like the vector representation produced right after the last real character, not one produced after a long run of padding.

This is where the pack operation comes in. It can be understood as compressing the padded variable-length sequences so that the padding character 0 is no longer present after compression. The concrete procedure is as follows (a sketch of the first step follows the list):

  • In the first step, the padded input sequence is passed through nn.utils.rnn.pack_padded_sequence, which returns an object of type PackedSequence that can be fed directly to the RNN (the forward function in the RNN source code first checks whether its input is a PackedSequence instance and takes a different code path if so, in which case the output is of the same type);
  • In the second step, the resulting PackedSequence object is passed to the RNN as usual, and the output obtained is also of this type;
  • In the third step, the output is passed through nn.utils.rnn.pad_packed_sequence, i.e. the RNN output is re-padded, giving back an ordinary equal-length sequence for each batch.
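The sketch below shows what the first step produces: the PackedSequence stores the data time step by time step with the padding removed (the tensors are ours, chosen for illustration):

import torch
import torch.nn as nn

x = torch.tensor([[1., 2., 3.],
                  [4., 5., 0.]]).unsqueeze(2)          # (batch=2, seq_len=3, 1)
packed = nn.utils.rnn.pack_padded_sequence(x, torch.tensor([3, 2]), batch_first=True)
print(packed.data.view(-1))   # tensor([1., 4., 2., 5., 3.]) -- the padding 0 is gone
print(packed.batch_sizes)     # tensor([2, 2, 1]) -- how many sequences are still "alive" at each time step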

3. Function details:

3.1 nn.utils.rnn.pack_padded_sequence

torch.nn.utils.rnn.pack_padded_sequence — PyTorch 1.11.0 documentation
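Summarizing the linked page, the signature in PyTorch 1.11 is:

from torch.nn.utils.rnn import pack_padded_sequence

# pack_padded_sequence(input, lengths, batch_first=False, enforce_sorted=True)
#   input:   a padded tensor, (B, T, *) if batch_first=True, else (T, B, *)
#   lengths: true length of each sequence (list of ints, or 1D int64 CPU tensor)
#   returns: a PackedSequence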

3.2 nn.utils.rnn.pad_packed_sequence

torch.nn.utils.rnn.pad_packed_sequence — PyTorch 1.11.0 documentation
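Again summarizing the linked page, the signature in PyTorch 1.11 is:

from torch.nn.utils.rnn import pad_packed_sequence

# pad_packed_sequence(sequence, batch_first=False, padding_value=0.0, total_length=None)
#   sequence:      a PackedSequence (e.g. the output of an RNN)
#   padding_value: value written at positions beyond each sequence's true length
#   total_length:  if given, pad to this length instead of the batch maximum
#   returns: (padded tensor, tensor of true lengths)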

4. Code example:

4.1 With packing:

import torch
import torch.nn as nn

gru = nn.GRU(input_size=1, hidden_size=1, batch_first=True)

# Four sequences padded with 0 to the batch maximum of 5; shape (4, 5, 1).
input = torch.tensor([[1,2,3,4,5],
                      [1,2,3,4,0],
                      [1,2,3,0,0],
                      [1,2,0,0,0]]).unsqueeze(2)
# The true (unpadded) length of each sequence.
input_lengths = torch.tensor([5,4,3,2])
# Compress away the padding; the result is a PackedSequence.
input = nn.utils.rnn.pack_padded_sequence(input, input_lengths, batch_first=True, enforce_sorted=False)
print(type(input))
print(input)
# The GRU accepts a PackedSequence directly; its output is packed as well.
output, hidden = gru(input.float())
# Re-pad the packed output back to an ordinary (4, 5, 1) tensor.
output, _ = torch.nn.utils.rnn.pad_packed_sequence(sequence=output, batch_first=True)

print(output)
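In this packed run, the re-padded output contains the padding value 0.0 at every position beyond a sequence's true length, and hidden holds each sequence's state at its last real step, so the padding zeros never influence the result.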

 

4.2 Without packing:

import torch
import torch.nn as nn

gru = nn.GRU(input_size=1, hidden_size=1, batch_first=True)

# The same padded batch as in 4.1, but fed to the GRU without packing.
input = torch.tensor([[1,2,3,4,5],
                      [1,2,3,4,0],
                      [1,2,3,0,0],
                      [1,2,0,0,0]]).unsqueeze(2)
input_lengths = torch.tensor([5,4,3,2])
# input = nn.utils.rnn.pack_padded_sequence(input, input_lengths, batch_first=True, enforce_sorted=False)
print(type(input))
print(input)
# The GRU now treats the padding zeros as real input steps.
output, hidden = gru(input.float())
# output, _ = torch.nn.utils.rnn.pad_packed_sequence(sequence=output, batch_first=True)

print(output)
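Running 4.2, the GRU updates its state over the padding zeros as if they were real inputs, so the outputs at padded positions are nonzero and hidden belongs to the last (padded) time step. A small sketch contrasting the two cases (the seed and tensors are ours):

import torch
import torch.nn as nn

torch.manual_seed(0)
gru = nn.GRU(input_size=1, hidden_size=1, batch_first=True)

x = torch.tensor([[1., 2., 3., 4., 5.],
                  [1., 2., 0., 0., 0.]]).unsqueeze(2)
lengths = torch.tensor([5, 2])

# With packing: hidden is taken at each sequence's last real step.
packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True)
_, h_packed = gru(packed)

# Without packing: every sequence is run for all 5 steps, so the second
# row's hidden state has also absorbed the three padding zeros.
_, h_padded = gru(x)

print(h_packed.view(-1))
print(h_padded.view(-1))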

 

5. Notes on a few parameters:

5.1 batch_first

For nn.RNN/nn.GRU/nn.LSTM as well as for these two functions, this parameter defaults to False, i.e. the first dimension of the input is expected to be the time step rather than the batch, which is the opposite of the layout we normally use, (batch_size, seq_len, embedding_dim). So either arrange the input data accordingly, or set this parameter to True.
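A quick illustration of the shape convention (the sizes are arbitrary):

import torch
import torch.nn as nn

# batch_first=False (default): input is (seq_len, batch, input_size)
# batch_first=True:            input is (batch, seq_len, input_size)
rnn = nn.GRU(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(4, 10, 8)   # (batch=4, seq_len=10, input_size=8)
out, h = rnn(x)
print(out.shape)            # torch.Size([4, 10, 16])
print(h.shape)              # torch.Size([1, 4, 16]) -- the hidden state is never batch first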

5.2 enforce_sorted

This parameter defaults to True, meaning the sequences in the batch are assumed to already be sorted in descending order of length. If they are not sorted, remember to set it to False.
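A small sketch with lengths that are not in descending order (the values are ours):

import torch
import torch.nn as nn

x = torch.tensor([[1., 2., 0.],
                  [3., 4., 5.]]).unsqueeze(2)
lengths = torch.tensor([2, 3])   # not sorted in descending order

# With the default enforce_sorted=True this call would raise an error;
# enforce_sorted=False lets pack_padded_sequence sort the batch internally.
packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True,
                                           enforce_sorted=False)
print(packed.sorted_indices)     # tensor([1, 0])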

 Partial reference: https://www.cnblogs.com/yuqinyuqin/p/14100967.html 

Origin: blog.csdn.net/m0_46483236/article/details/124136437