pytorch mask_filled用法

 #将 mask必须是一个 ByteTensor 而且shape必须和 a一样 并且元素只能是 0或者1 ,是将 mask中为1的 元素所在的索引,在a中相同的的索引处替换为 value  ,mask value必须同为tensor 

a=torch.tensor([1,0,2,3])
    # a.masked_fill(mask = torch.ByteTensor([1,1,0,0]), value=torch.tensor(-1e9))

# tensor([-1.0000e+09, -1.0000e+09,  2.0000e+00,  3.0000e+00])

猜你喜欢

转载自blog.csdn.net/candy134834/article/details/84594754
今日推荐