torch之where&&gather

where&&gather

##1、where

cond=torch.rand(3,3)
tensor([[0.3996, 0.0690, 0.9096],
[0.5240, 0.2011, 0.9113],
[0.6207, 0.0227, 0.9679]])

a=torch.ones([3,3])
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])

b=torch.zeros([3,3])
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]])

c=torch.where(cond>0.5,a,b)–>a,b,cond均是同一个size,语义:cond[i]>0.5,a[i]->c[i],否则b[i]->c[i]======>等价于
在for i in range(3):(cpu中实现,而torch.where可在gpu中实现,加快速度)
for j in range(3):
if cond >0.5:
a[i]->c[i]
else:
b[i]->c[i]
tensor([[0., 0., 1.],
[1., 0., 1.],
[1., 0., 1.]])

##2、Gather
import torch
a=[‘鼠’,‘牛’,‘虎’,‘兔’,‘龙’,‘蛇’,‘马’,‘羊’,‘猴’,‘鸡’,‘狗’,‘猪’]
a_num=torch.arange(0,12)
cond=torch.rand(4,12)–>看做4组12生肖分类置信度
tensor([[0.8059, 0.6237, 0.0252, 0.6578, 0.8973, 0.1721, 0.0973, 0.5686, 0.2628,0.1727, 0.5576, 0.2738],
[0.3605, 0.7545, 0.6452, 0.3503, 0.9368, 0.7807, 0.1228, 0.5853, 0.4578,0.8317, 0.0125, 0.4119],
[0.6783, 0.7420, 0.5800, 0.9243, 0.7156, 0.3565, 0.2481, 0.4220, 0.4311, 0.9624, 0.2303, 0.4836],
[0.4644, 0.9783, 0.1354, 0.2682, 0.2635, 0.4039, 0.9332, 0.6616, 0.4908, 0.3773, 0.2385, 0.8812]])
label=cond.topk(k=3,dim=1)–>取置信度前3的索引
label=label[1]
result=torch.gather(a_num.expand(4,12),dim=1,index=label)–>在label的标签值作为索引位置返回a_num对应位置的值(查表),与where一样可利用gpu加速
result=tensor([[ 4, 0, 3], [ 4, 9, 5], [ 9, 3, 1], [ 1, 6, 11]])
在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_40859802/article/details/103765913