1.公式ドキュメント:
torch.nn —PyTorch1.11.0のドキュメント
2.アプリケーションの背景:
pytorchを使用してデータを処理する場合、通常、複数のサンプルシーケンスがバッチの形式で同時に処理され、各バッチのサンプルシーケンスの長さが等しくないため、rnnはそれらを処理できません。したがって、通常の方法では、最初に、同じ長さの形式で最長のシーケンスに従って各バッチをパディングします。
ただし、パディング操作は問題を引き起こします。つまり、パディングされたほとんどのシーケンスでは、rnnが多くの役に立たない文字でそれを表すことになります。シーケンスが最後の有用な文字の後に出力できることを願っています。ベクトル表現、多くのパディング文字の後ではありません。
このとき、パック操作が始まります。パディング後に可変長シーケンスを圧縮し、圧縮後はパディング文字0を含まないことがわかります。具体的な操作は次のとおりです。
- 最初のステップである、パディング後の入力シーケンスは、最初にnn.utils.rnn.pack_padded_sequenceを通過します。これにより、PackedSequenceタイプのオブジェクトが取得され、RNNに直接渡すことができます(RNNのソースコードのフォワード関数が判断されます)入力がPackedSequenceであるかどうか、インスタンスは別のアクションを実行します。その場合、出力はそのタイプです。) ;
- 2番目のステップでは、通常、取得したPackedSequenceタイプのオブジェクトがRNNに直接渡され、このタイプの出力も取得されます。
- 3番目のステップは、nn.utils.rnn.pad_packed_sequenceを実行することです。つまり、RNNの後に出力を再パディングして、各バッチで同じ長さの通常のシーケンスを取得します。
3.機能の詳細:
3.1 nn.utils.rnn.pack_padded_sequence
torch.nn.utils.rnn.pack_padded_sequence —PyTorch1.11.0のドキュメント
3.2 nn.utils.rnn.pad_packed_sequence
torch.nn.utils.rnn.pad_packed_sequence —PyTorch1.11.0のドキュメント
4.コード例:
4.1使用する場合:
import torch
import torch.nn as nn
gru = nn.GRU(input_size=1, hidden_size=1, batch_first=True)
input = torch.tensor([[1,2,3,4,5],
[1,2,3,4,0],
[1,2,3,0,0],
[1,2,0,0,0]]).unsqueeze(2)
input_lengths = torch.tensor([5,4,3,2])
input = nn.utils.rnn.pack_padded_sequence(input, input_lengths, batch_first=True, enforce_sorted=False)
print(type(input))
print(input)
output, hidden = gru(input.float())
output, _ = torch.nn.utils.rnn.pad_packed_sequence(sequence=output, batch_first=True)
print(output)
4.2使用しない場合:
import torch
import torch.nn as nn
gru = nn.GRU(input_size=1, hidden_size=1, batch_first=True)
input = torch.tensor([[1,2,3,4,5],
[1,2,3,4,0],
[1,2,3,0,0],
[1,2,0,0,0]]).unsqueeze(2)
input_lengths = torch.tensor([5,4,3,2])
# input = nn.utils.rnn.pack_padded_sequence(input, input_lengths, batch_first=True, enforce_sorted=False)
print(type(input))
print(input)
output, hidden = gru(input.float())
# output, _ = torch.nn.utils.rnn.pad_packed_sequence(sequence=output, batch_first=True)
print(output)
5.いくつかのパラメーターに注意してください。
5.1 batch_first
RNNを含めると、パラメーターのデフォルトはFalseになります。つまり、通常の入力とは異なり、入力の最初の次元がバッチにならないようにします。これまでの入力は(batch_size、seq_len、embedding_dim)なので、次のようになります。注意する必要があるか、データ入力。このパラメータをTrueに設定してください。
5.2enforce_sorted
パラメータのデフォルトはTrueです。これは、デフォルトバッチの各シーケンスが長さの降順で配置されていることを意味します。したがって、ソートされていない場合はFalseに変更されることに注意してください。