pytorch-高阶操作

高阶操作

where

torch.where(condition,x,y)
(条件,A矩阵,B矩阵),符合条件从A对应位置选数,不符合就从B对应位置选数
条件选择函数

a=torch.full([2,2],0)
b=torch.full([2,2],1)
a,b
'''
(tensor([[0, 0],
         [0, 0]]),
 tensor([[1, 1],
         [1, 1]]))
'''
#条件
cond=torch.rand(2,2)
cond
'''
tensor([[0.6961, 0.8969],
        [0.2795, 0.9759]])
'''

torch.where(cond>0.5,a,b)
#如果cond的值>0.5就选取a对应位置的数,不是就选取b中对应位置的数
'''
tensor([[0, 0],
        [1, 0]])
'''

gather

torch.gather(input,dim,index,out=None)-Tensor
input,输入数据
dim,查看维度
index,查看索引
out=None

prob=torch.randn(4,10)
'''
tensor([[ 0.6805, -0.4651,  0.6448,  0.6679, -0.5646,  2.3565,  0.9479, -0.0406,
         -0.4645,  1.3624],
        [ 0.8647, -0.5109,  0.5100,  0.6534, -0.8373, -1.8661, -0.8300, -0.0230,
         -0.2076,  0.6472],
        [ 0.9843,  1.0484,  0.1264, -1.2768,  0.7247,  0.9827,  1.1230,  0.9566,
          0.4962, -0.9180],
        [ 1.3375,  0.7297, -0.8324,  0.5294, -1.7625,  0.7328,  0.9702, -0.0741,
          2.6688,  0.1584]])
'''

#得到按第一维排序的top3的数的大小以及位置,输出形式与原来的数一样
idx=prob.topk(dim=1,k=3)
idx
'''
torch.return_types.topk(
values=tensor([[2.3565, 1.3624, 0.9479],
        [0.8647, 0.6534, 0.6472],
        [1.1230, 1.0484, 0.9843],
        [2.6688, 1.3375, 0.9702]]),
indices=tensor([[5, 9, 6],
        [0, 3, 9],
        [6, 1, 0],
        [8, 0, 6]]))
'''

label=torch.arange(10)+100
label
'''
tensor([100, 101, 102, 103, 104, 105, 106, 107, 108, 109])
'''
idx=idx[1]#只得到第二部分即对应数的位置
idx
'''
tensor([[5, 9, 6],
        [0, 3, 9],
        [6, 1, 0],
        [8, 0, 6]])
        '''

idx.long()
'''tensor([[5, 9, 6],
        [0, 3, 9],
        [6, 1, 0],
        [8, 0, 6]])
'''
#按idx的下标在label.expand中查找对应的数
torch.gather(label.expand(4,10),dim=1,index=idx.long())
'''
tensor([[105, 109, 106],
        [100, 103, 109],
        [106, 101, 100],
        [108, 100, 106]])
'''

猜你喜欢

转载自blog.csdn.net/qq_43894221/article/details/127113813
今日推荐