pytorch:where & gather

import torch

where

  • where(condition,a,b) 满足条件,返回a里面的对应元素,不满足条件,返回b里面对应的元素
a = torch.rand(2,3)
b = torch.rand(2,3)
torch.where(a>b,a,b)
tensor([[0.2254, 0.7619, 0.9761],
        [0.7787, 0.4238, 0.8476]])

gather

  • 设X的size是P·Q·M

  • Z=X.gather(dim=0,index=Y) Z是一个size和Y一样的tensor

  • 对Y的要求是:Y的size=N·Q·M N是任意正整数,对构成Y的元素的取值范围是[0,P-1]

  • 现在有一个tensor W,W的size是Q·M,Wqm为X1qm,X2qm…Xnqmzip在一起的list,即Wqm=(X1qm,X2qm…Xnqm)

    • step1:按照dim的取值对X中的数据进行zip,形成W
    • step2:形成index矩阵Y
    • step3:依照Y,从W中的每个小zip中取值
      在这里插入图片描述

例1:

test=torch.randint(0,11,[2,3,4])
test
tensor([[[ 6,  2, 10,  1],
         [ 6,  0,  0,  9],
         [ 2,  7,  6,  5]],

        [[ 1,  4,  8,  5],
         [ 0, 10,  0,  0],
         [ 4,  6,  8,  9]]])
index = torch.randint(0,2,[3,3,4])
index
tensor([[[0, 0, 1, 0],
         [0, 1, 1, 1],
         [1, 0, 0, 1]],

        [[0, 1, 0, 1],
         [0, 0, 0, 0],
         [0, 0, 0, 1]],

        [[0, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 0, 1, 1]]])
test.gather(dim=0,index=index)
tensor([[[ 6,  2,  8,  1],
         [ 6, 10,  0,  0],
         [ 4,  7,  6,  9]],

        [[ 6,  4, 10,  5],
         [ 6,  0,  0,  9],
         [ 2,  7,  6,  9]],

        [[ 6,  4, 10,  1],
         [ 0, 10,  0,  9],
         [ 4,  7,  8,  9]]])

例2

a = torch.randint(1,11,[2,3,4])
a
tensor([[[ 5, 10,  5,  5],
         [10,  8,  4,  2],
         [ 6,  5,  1,  9]],

        [[ 1,  7,  3,  3],
         [ 8,  5,  8,  2],
         [ 7,  5, 10,  1]]])
index = torch.randint(0,3,[2,2,4])
index
tensor([[[0, 0, 0, 0],
         [1, 0, 1, 1]],

        [[1, 2, 2, 0],
         [2, 2, 2, 0]]])
a.gather(dim=1,index=index)
#这里dim = 1,所以把(5,10,6)zip在一起,(10,8,5)zip在一起...(3,2,1)zip在一起,然后按顺序把这些小zip排成2*4的tensor
#然后拿着index在这个tensor中选
tensor([[[ 5, 10,  5,  5],
         [10, 10,  4,  2]],

        [[ 8,  5, 10,  3],
         [ 7,  5, 10,  3]]])

例3

a = torch.randint(1,11,[2,3,3,2])
a
tensor([[[[10,  7],
          [ 9,  7],
          [ 9,  1]],

         [[10,  5],
          [ 4,  5],
          [ 3,  7]],

         [[10,  8],
          [ 4, 10],
          [ 6,  1]]],


        [[[ 1,  4],
          [ 3,  2],
          [ 1,  7]],

         [[ 5,  2],
          [ 1,  4],
          [ 9, 10]],

         [[ 2,  5],
          [ 5,  5],
          [ 7,  4]]]])
index=torch.randint(0,2,[2,3,3,1])
index
tensor([[[[0],
          [1],
          [0]],

         [[0],
          [0],
          [1]],

         [[1],
          [1],
          [1]]],


        [[[1],
          [0],
          [1]],

         [[1],
          [0],
          [0]],

         [[0],
          [0],
          [0]]]])
a.gather(dim=3,index=index)
#dim=3 所以把(10,7)zip在一起,(9,7)zip在一起,(9,1)zip在一起,...(7,4)zip在一起
#...
tensor([[[[10],
          [ 7],
          [ 9]],

         [[10],
          [ 4],
          [ 7]],

         [[ 8],
          [10],
          [ 1]]],


        [[[ 4],
          [ 3],
          [ 7]],

         [[ 2],
          [ 1],
          [ 9]],

         [[ 2],
          [ 5],
          [ 7]]]])

发布了43 篇原创文章 · 获赞 1 · 访问量 758

猜你喜欢

转载自blog.csdn.net/weixin_41391619/article/details/104699555