Operación de datos de máscara de antorcha Función Masked_select() y función Masked_scatter()

selección_enmascarada()

Seleccionar datos de máscara
import torch
x = torch.tensor([[-2,-1,0],[-1,0,1],[0,1,2]])
x
tensor([[-2, -1,  0],
        [-1,  0,  1],
        [ 0,  1,  2]])
mask = x>0
mask
tensor([[False, False, False],
        [False, False,  True],
        [False,  True,  True]])
torch.masked_select(x,mask)
tensor([1, 1, 2])

dispersión_enmascarada()

Reemplazar datos de máscara
variable = torch.tensor([[100,100,100],[100,100,100],[100,100,100]])
variable                                                            
tensor([[100, 100, 100],
        [100, 100, 100],
        [100, 100, 100]])
x.masked_scatter_(mask,variable)                                    
tensor([[ -2,  -1,   0],
        [ -1,   0, 100],
        [  0, 100, 100]])
de otra manera
variable_new = torch.tensor([100,100,100])                          
variable_new                              
tensor([100, 100, 100])
x.masked_scatter_(mask,variable_new)      
tensor([[ -2,  -1,   0],
        [ -1,   0, 100],
        [  0, 100, 100]])
Si la longitud variable es menor que la longitud de la máscara, no se puede reemplazar
variable_current = torch.tensor([100,100])        
variable_current                          
tensor([100, 100])
x.masked_scatter_(mask,variable_current)  
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Number of elements of source < number of ones in mask

Supongo que te gusta

Origin blog.csdn.net/weixin_46398647/article/details/126857664
Recomendado
Clasificación