Resumen de conocimientos básicos recientemente utilizados en Python (8)

1. La función scatter () de pytorch

 El papel de scatter ()  y  scatter_ () es el mismo, excepto que scatter () no modificará directamente el tensor original, mientras que scatter_ () sí.

Hay 3 parámetros para la dispersión (dim, index, src)

  • dim: a lo largo de qué dimensión indexar
  • índice: el índice del elemento utilizado para la dispersión
  • src: el elemento de origen utilizado para la dispersión, que puede ser un escalar o un tensor
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 2)
tensor([[2., 2., 2., 2., 2.],
        [0., 2., 0., 2., 0.],
        [2., 0., 2., 0., 2.]])

 En el ejemplo anterior, el primer parámetro '0' representa la dimensión 0, que es la fila del ejemplo. El segundo parámetro'torch.tensor ([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]) 'representa el índice del primer elemento. En resumen, esos lugares deben Las operaciones de valor, por ejemplo, [0,1,2,0,0] en la primera fila de la primera fila significan cambiar los valores de los elementos en 0, 1, 2, 0, 0 en el misma columna. El tercer parámetro '2' significa que el valor de cambio final se establece en 2.

Referencia: https://www.cnblogs.com/dogecheng/p/11938009.html

Extender a matriz 3d

>>> torch.zeros(2, 3, 5).scatter_(1, torch.tensor([[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]],[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]]), 2)
tensor([[[2., 2., 2., 2., 2.],
         [0., 2., 0., 2., 0.],
         [2., 0., 2., 0., 2.]],

        [[2., 2., 2., 2., 2.],
         [0., 2., 0., 2., 0.],
         [2., 0., 2., 0., 2.]]])

Compárelo con el primer parámetro de dimensión para determinar qué dimensión transformar. (Generalmente se usa cuando la máscara de segmentación semántica se convierte en una etiqueta única)

Supongo que te gusta

Origin blog.csdn.net/qq_36401512/article/details/112633862
Recomendado
Clasificación