pytorch中 nn.utils.rnn.pack_padded_sequence和nn.utils.rnn.pad_packed_sequence

1. Documentación oficial:

antorcha.nn — documentación de PyTorch - 1.11.0

 

2. Antecedentes de la aplicación:

Cuando se usa pytorch para procesar datos, generalmente se procesan varias secuencias de muestra al mismo tiempo en forma de lotes, y las secuencias de muestra en cada lote tienen una longitud desigual, lo que hace que rnn no pueda procesarlas. Por lo tanto, la práctica habitual es rellenar primero cada lote de acuerdo con la secuencia más larga en forma de igual longitud.

Pero la operación de relleno traerá un problema, es decir, para la mayoría de las secuencias que han sido rellenadas, hará que rnn la represente con muchos caracteres inútiles. Esperamos que la secuencia pueda salir después del último carácter útil. Representación vectorial , no después de muchos caracteres de relleno.

En este momento, entra en juego la operación de paquete, se puede entender que comprime una secuencia de longitud variable después del relleno y no contiene el carácter de relleno 0 después de la compresión. La operación específica es:

  • El primer paso, la secuencia de entrada después del relleno pasa primero a través de nn.utils.rnn.pack_padded_sequence, que obtendrá un objeto de tipo PackedSequence, que se puede pasar directamente a RNN (la función de reenvío en el código fuente de RNN aparece para juzgar si la entrada es PackedSequence o no instancia, que a su vez toma una acción diferente, y si es así, la salida es de ese tipo.) ;
  • En el segundo paso, el objeto obtenido de tipo PackedSequence normalmente se pasa directamente a la RNN, y también se obtiene la salida de este tipo ;
  • El tercer paso es pasar por nn.utils.rnn.pad_packed_sequence, es decir, volver a rellenar la salida después de RNN para obtener una secuencia normal de igual longitud para cada lote.

3. Detalles de la función:

3.1 nn.utils.rnn.pack_padded_secuencia

torch.nn.utils.rnn.pack_padded_sequence — documentación de PyTorch - 1.11.0

3.2 nn.utils.rnn.pad_packed_sequence

torch.nn.utils.rnn.pad_packed_sequence — documentación de PyTorch - 1.11.0

4. Ejemplo de código:

4.1 Al usar:

import torch
import torch.nn as nn

gru = nn.GRU(input_size=1, hidden_size=1, batch_first=True)

input = torch.tensor([[1,2,3,4,5],
                      [1,2,3,4,0],
                      [1,2,3,0,0],
                      [1,2,0,0,0]]).unsqueeze(2)
input_lengths = torch.tensor([5,4,3,2])
input = nn.utils.rnn.pack_padded_sequence(input, input_lengths, batch_first=True, enforce_sorted=False)
print(type(input))
print(input)
output, hidden = gru(input.float())
output, _ = torch.nn.utils.rnn.pad_packed_sequence(sequence=output, batch_first=True)

print(output)

 

4.2 Cuando no esté en uso:

import torch
import torch.nn as nn

gru = nn.GRU(input_size=1, hidden_size=1, batch_first=True)

input = torch.tensor([[1,2,3,4,5],
                      [1,2,3,4,0],
                      [1,2,3,0,0],
                      [1,2,0,0,0]]).unsqueeze(2)
input_lengths = torch.tensor([5,4,3,2])
# input = nn.utils.rnn.pack_padded_sequence(input, input_lengths, batch_first=True, enforce_sorted=False)
print(type(input))
print(input)
output, hidden = gru(input.float())
# output, _ = torch.nn.utils.rnn.pad_packed_sequence(sequence=output, batch_first=True)

print(output)

 

5. Tenga en cuenta algunos parámetros:

5.1 lote_primero

Incluyendo RNN, el parámetro predeterminado es Falso, es decir, fomenta que la primera dimensión de la entrada no sea por lotes, lo que es contrario a nuestra entrada normal. La entrada a la que estamos acostumbrados es (batch_size, seq_len, embedding_dim), por lo que necesita prestar atención, o la entrada de datos, o establezca este parámetro en Verdadero.

5.2 enforce_sorted

El parámetro predeterminado es True, lo que significa que cada secuencia en el lote predeterminado se ha organizado en orden descendente de longitud, por lo que debe tenerse en cuenta que si no se ordena, se cambiará a False.

 Referencia parcial: https://www.cnblogs.com/yuqinyuqin/p/14100967.html 

Supongo que te gusta

Origin blog.csdn.net/m0_46483236/article/details/124136437
Recomendado
Clasificación