pytorch中nn.utils.rnn.pack_padded_sequence和nn.utils.rnn.pad_packed_sequence

1.公式ドキュメント:

torch.nn —PyTorch1.11.0のドキュメント

 

2.アプリケーションの背景:

pytorchを使用してデータを処理する場合、通常、複数のサンプルシーケンスがバッチの形式で同時に処理され、各バッチのサンプルシーケンスの長さが等しくないため、rnnはそれらを処理できません。したがって、通常の方法では、最初に、同じ長さの形式で最長のシーケンスに従って各バッチをパディングします。

ただし、パディング操作は問題を引き起こします。つまり、パディングされたほとんどのシーケンスでは、rnnが多くの役に立たない文字でそれを表すことになります。シーケンスが最後の有用な文字の後に出力できることを願っています。ベクトル表現、多くのパディング文字の後ではありません。

このとき、パック操作が始まります。パディング後に可変長シーケンスを圧縮し、圧縮後はパディング文字0を含まないことがわかります。具体的な操作は次のとおりです。

  • 最初のステップである、パディング後の入力シーケンスは、最初にnn.utils.rnn.pack_padded_sequenceを通過します。これにより、PackedSequenceタイプのオブジェクトが取得され、RNNに直接渡すことができます(RNNのソースコードのフォワード関数が判断されます)入力がPackedSequenceであるかどうか、インスタンスは別のアクションを実行します。その場合、出力はそのタイプです。) ;
  • 2番目のステップでは、通常、取得したPackedSequenceタイプのオブジェクトがRNNに直接渡され、このタイプの出力も取得されます。
  • 3番目のステップは、nn.utils.rnn.pad_packed_sequenceを実行することです。つまり、RNNの後に出力を再パディングして、各バッチで同じ長さの通常のシーケンスを取得します。

3.機能の詳細:

3.1 nn.utils.rnn.pack_padded_sequence

torch.nn.utils.rnn.pack_padded_sequence —PyTorch1.11.0のドキュメント

3.2 nn.utils.rnn.pad_packed_sequence

torch.nn.utils.rnn.pad_packed_sequence —PyTorch1.11.0のドキュメント

4.コード例:

4.1使用する場合:

import torch
import torch.nn as nn

gru = nn.GRU(input_size=1, hidden_size=1, batch_first=True)

input = torch.tensor([[1,2,3,4,5],
                      [1,2,3,4,0],
                      [1,2,3,0,0],
                      [1,2,0,0,0]]).unsqueeze(2)
input_lengths = torch.tensor([5,4,3,2])
input = nn.utils.rnn.pack_padded_sequence(input, input_lengths, batch_first=True, enforce_sorted=False)
print(type(input))
print(input)
output, hidden = gru(input.float())
output, _ = torch.nn.utils.rnn.pad_packed_sequence(sequence=output, batch_first=True)

print(output)

 

4.2使用しない場合:

import torch
import torch.nn as nn

gru = nn.GRU(input_size=1, hidden_size=1, batch_first=True)

input = torch.tensor([[1,2,3,4,5],
                      [1,2,3,4,0],
                      [1,2,3,0,0],
                      [1,2,0,0,0]]).unsqueeze(2)
input_lengths = torch.tensor([5,4,3,2])
# input = nn.utils.rnn.pack_padded_sequence(input, input_lengths, batch_first=True, enforce_sorted=False)
print(type(input))
print(input)
output, hidden = gru(input.float())
# output, _ = torch.nn.utils.rnn.pad_packed_sequence(sequence=output, batch_first=True)

print(output)

 

5.いくつかのパラメーターに注意してください。

5.1 batch_first

RNNを含めると、パラメーターのデフォルトはFalseになります。つまり、通常の入力とは異なり、入力の最初の次元がバッチにならないようにします。これまでの入力は(batch_size、seq_len、embedding_dim)なので、次のようになります。注意する必要があるか、データ入力。このパラメータをTrueに設定してください。

5.2enforce_sorted

パラメータのデフォルトはTrueです。これは、デフォルトバッチの各シーケンスが長さの降順で配置されていることを意味します。したがって、ソートされていない場合はFalseに変更されることに注意してください。

 部分参照:https ://www.cnblogs.com/yuqinyuqin/p/14100967.html 

おすすめ

転載: blog.csdn.net/m0_46483236/article/details/124136437