pytorch按维度取数据0917

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/jacke121/article/details/82745433

这个是三维的:

import numpy as np

import torch


x = torch.linspace(1,27,steps=27).view(3,3,3)
y= torch.ones(2,2)#.numpy()#.astype(np.uint8)#.type(torch.uint8)
#[第一维][第二维][第三维]
y=np.asarray([[0,1,2],[2,0,1],[1,0,2]])
# y[0,1]=0
x[y]=0

print(x)
import torch

a=torch.linspace(1,8,steps=8).view(2,2,2)
#这个是正确的:

b=torch.linspace(1,4,steps=4).view(2,2).view(4,1)
c=torch.linspace(0,1,steps=2).repeat(2).view(4,1)

d=torch.cat((b,c),1).view(2,2,2).type(torch.uint8)
# print(a)
print(d)
# d[0,0,0]=1
a[d]=0
print(a)

d维度相同,则相同位置,如果>0则会选中,否则,不会选中

tensor([[[1, 0],
         [2, 1]],

        [[3, 0],
         [4, 1]]], dtype=torch.uint8)
tensor([[[0., 2.],
         [0., 0.]],

        [[0., 6.],
         [0., 0.]]])

这里来个四维的:

import numpy as np

import torch
from numpy.core.defchararray import count

x = torch.linspace(1,81,steps=81).view(3,3,3,3)


g_index=np.asarray([0,1,2])
#[第一维][第二维][第三维]
y=np.asarray([[0,1,2],[2,0,1],[1,0,2]])

anch_ious= torch.linspace(1,9,steps=9).view(3,3)

values = torch.max(anch_ious, 1, keepdim=True)[0]

c = anch_ious - values
gj=np.asarray([2,0,1])
gi=np.asarray([1,0,2])
best_n = (c == 0)  # (c == 0).type(torch.uint8)
# print(best_n)
# x[best_n][g_index, gj] = 1

count=0
for i in (x[best_n, gj,gi]):
    count+=1
print(count)
x[best_n, gj,gi] = 1

# y[0,1]=0
# x[y]=0

print(x)

猜你喜欢

转载自blog.csdn.net/jacke121/article/details/82745433