pytorch --- tensor.squeeze (DIM) 和 unsqueeze (DIM)

tensor.squeeze (DIM)

Acción : Si la dimensión tenue valor especificado 1, entonces la dimensión de borrado, si el valor de la dimensión especificada no es 1, se devuelve el original tensor

Ejemplos:

x = torch.rand(2,1,3)
print(x)
print(x.squeeze(1))
print(x.squeeze(2))

salida:

tensor([[[0.7031, 0.7173, 0.0606]],

        [[0.6884, 0.4072, 0.0516]]])
        
tensor([[0.7031, 0.7173, 0.0606],
        [0.6884, 0.4072, 0.0516]])

tensor([[[0.7031, 0.7173, 0.0606]],

        [[0.6884, 0.4072, 0.0516]]])

Como se muestra en los resultados de: x.shape = [2, 1, 3], un valor de una primera dimensión, de modo x.squeeze (dim = 1) de salida se retire la primera dimensión, la forma de salida = [2, 3], el segundo valor de dimensión no es 1, por lo x.squeeze (dim = 2) constante de salida tensor de la forma de

tensor.unsqueeze (DIM)

Esta función es principalmente para ampliar dimensiones de datos. Posición además de una dimensión especificada es una dimensión, tal como los datos originales de un Tres líneas (3), además de la posición unidimensional 0 se convierte en la línea de tres (3). Otra forma es b = torch.squeeze (tensor, dim) es para especificar la posición del tensor en dim adición de una dimensión a una dimensión de 1

Ejemplos:

x = torch.rand(2,3)
print(x)
print("x.shape:", x.shape)
y = torch.unsqueeze(x, 1)
print(y)
print("y.shape:", y.shape)
z = x.unsqueeze(2)
print(z)
print("z.shape:", z.shape)

salida:

tensor([[0.1255, 0.7249, 0.5253],
        [0.9247, 0.4592, 0.3944]])
x.shape: torch.Size([2, 3])


tensor([[[0.1255, 0.7249, 0.5253]],

        [[0.9247, 0.4592, 0.3944]]])
y.shape: torch.Size([2, 1, 3])


tensor([[[0.1255],
         [0.7249],
         [0.5253]],

        [[0.9247],
         [0.4592],
         [0.3944]]])
z.shape: torch.Size([2, 3, 1])
[Finished in 2.6s]
Publicado 33 artículos originales · ganado elogios 1 · vistas 2611

Supongo que te gusta

Origin blog.csdn.net/orangerfun/article/details/104012564
Recomendado
Clasificación