【Pytorch】高阶操作

1. where 函数

  • 源码定义:
def where(condition: Tensor) -> Union[Tuple[Tensor, ...], List[Tensor]]: ...
  • torch.where 函数功能如下:
torch.where(condition , x, y)

某元素满足条件使用 x Tensor 来填充,不满足条件使用 y Tensor 来填充,其中 x 和 y 应当与原 Tensor 维度及 size 相同

cond =  torch.rand(2, 2)
print(cond)

a = torch.zeros(2, 2)
b = torch.ones(2, 2)

c = torch.where(cond>0.5, a, b)
print(c)

输出:

tensor([[0.4716, 0.8124],
        [0.3771, 0.1771]])
tensor([[1., 0.],
        [1., 1.]])

2. gather 函数

def gather(input: Tensor, dim, index: Tensor) -> Tensor: ...
  • 根据 index Tensor 中的值作为 input Tensor 中的索引,生成一个新的 Tensor

例子

prob = torch.randn(4, 10)	
idx = prob.topk(dim=1, k=3)	# 选出最有可能的 3 种
print(idx)
'''
values=tensor([[1.2655, 0.5347, 0.4686],
                [1.9430, 1.1472, 1.1349],
                [1.2370, 0.8487, 0.7665],
                [2.0423, 2.0380, 1.0663]]),
indices=tensor([[8, 6, 7],
                [8, 9, 5],
                [2, 5, 7],
                [5, 2, 3]]))
'''
idx = idx[1]	# 每一个照片最有可能的 3 种情况
print(idx)
'''
tensor([[8, 6, 7],
        [8, 9, 5],
        [2, 5, 7],
        [5, 2, 3]])
'''
label = torch.arange(10) * 100
print(label)
'''
tensor([  0, 100, 200, 300, 400, 500, 600, 700, 800, 900])
'''
ret = torch.gather(label.expand(4, 10), dim=1, index=idx.long())
print(ret)
'''
tensor([[800, 600, 700],
        [800, 900, 500],
        [200, 500, 700],
        [500, 200, 300]])
'''

图解

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/weixin_45437022/article/details/114296505
今日推荐