1. La función scatter () de pytorch
El papel de scatter () y scatter_ () es el mismo, excepto que scatter () no modificará directamente el tensor original, mientras que scatter_ () sí.
Hay 3 parámetros para la dispersión (dim, index, src)
- dim: a lo largo de qué dimensión indexar
- índice: el índice del elemento utilizado para la dispersión
- src: el elemento de origen utilizado para la dispersión, que puede ser un escalar o un tensor
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 2)
tensor([[2., 2., 2., 2., 2.],
[0., 2., 0., 2., 0.],
[2., 0., 2., 0., 2.]])
En el ejemplo anterior, el primer parámetro '0' representa la dimensión 0, que es la fila del ejemplo. El segundo parámetro'torch.tensor ([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]) 'representa el índice del primer elemento. En resumen, esos lugares deben Las operaciones de valor, por ejemplo, [0,1,2,0,0] en la primera fila de la primera fila significan cambiar los valores de los elementos en 0, 1, 2, 0, 0 en el misma columna. El tercer parámetro '2' significa que el valor de cambio final se establece en 2.
Referencia: https://www.cnblogs.com/dogecheng/p/11938009.html
Extender a matriz 3d
>>> torch.zeros(2, 3, 5).scatter_(1, torch.tensor([[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]],[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]]), 2)
tensor([[[2., 2., 2., 2., 2.],
[0., 2., 0., 2., 0.],
[2., 0., 2., 0., 2.]],
[[2., 2., 2., 2., 2.],
[0., 2., 0., 2., 0.],
[2., 0., 2., 0., 2.]]])
Compárelo con el primer parámetro de dimensión para determinar qué dimensión transformar. (Generalmente se usa cuando la máscara de segmentación semántica se convierte en una etiqueta única)