python 近期用到的基础知识汇总(八)

1.pytorch 的scatter()函数

scatter() 和 scatter_() 的作用是一样的,只不过 scatter() 不会直接修改原来的 Tensor,而 scatter_() 会.

scatter(dim, index, src) 的参数有 3 个

  • dim:沿着哪个维度进行索引
  • index:用来 scatter 的元素索引
  • src:用来 scatter 的源元素,可以是一个标量或一个张量
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 2)
tensor([[2., 2., 2., 2., 2.],
        [0., 2., 0., 2., 0.],
        [2., 0., 2., 0., 2.]])

 上面的例子中第一个参数‘0’代表的是第0维度,例子中也就是行。第二个参数‘torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])’ 代表第元素索引,简单来说那些地方要进行更改元素值的操作,例如里面第一行的[0,1,2,0,0]代表分别要对它们同列的0,1,2,0,0处元素值进行更改。第三个参数‘2’代表最终更改值设定为2.

参考:https://www.cnblogs.com/dogecheng/p/11938009.html

扩展到3d矩阵

>>> torch.zeros(2, 3, 5).scatter_(1, torch.tensor([[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]],[[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]]), 2)
tensor([[[2., 2., 2., 2., 2.],
         [0., 2., 0., 2., 0.],
         [2., 0., 2., 0., 2.]],

        [[2., 2., 2., 2., 2.],
         [0., 2., 0., 2., 0.],
         [2., 0., 2., 0., 2.]]])

对比一下看第一个维度参数是确定在哪个维度变换。(一般语义分割掩码转变成one-hot标签时用到)

猜你喜欢

转载自blog.csdn.net/qq_36401512/article/details/112633862