torch中的替换操作

目录

1.1通过比较操作得到布尔矩阵

1.2布尔矩阵作为索引

1.3布尔矩阵的强转


 1.通过比较替换

1.1通过比较操作得到布尔矩阵

a = torch.rand((5, 6), dtype=torch.float32)
print(a)
print(a > 0.5)
-----------------------------------------------------------------------------------------
tensor([[0.7172, 0.0393, 0.0810, 0.7734, 0.0044, 0.3191],
        [0.3387, 0.2647, 0.3805, 0.0339, 0.5828, 0.2851],
        [0.9486, 0.1131, 0.4608, 0.9621, 0.2813, 0.1919],
        [0.8794, 0.8339, 0.7018, 0.3440, 0.8190, 0.2513],
        [0.4571, 0.0413, 0.0869, 0.0827, 0.2947, 0.1362]])
tensor([[ True, False, False,  True, False, False],
        [False, False, False, False,  True, False],
        [ True, False, False,  True, False, False],
        [ True,  True,  True, False,  True, False],
        [False, False, False, False, False, False]])

1.2布尔矩阵作为索引

当布尔矩阵出现在下标位置时,充当索引的角色,可以进行赋值、数学运算等操作。

a[a > 0.5] = 1

对应于True位置的数值替换为1.0 

a = torch.rand((5, 6), dtype=torch.float32)
index = a > 0.5
print(a)
print(a > 0.5)
a[a > 0.5] = 1
print(a)

-------------------------------------------------------------------------------------------
tensor([[0.7869, 0.9450, 0.6493, 0.3837, 0.0440, 0.7851],
        [0.0643, 0.2934, 0.4264, 0.5007, 0.4667, 0.8721],
        [0.7292, 0.8509, 0.0858, 0.7933, 0.3860, 0.4047],
        [0.6582, 0.5485, 0.4721, 0.8394, 0.6033, 0.9711],
        [0.4886, 0.2380, 0.7981, 0.9216, 0.4049, 0.7381]])
tensor([[ True,  True,  True, False, False,  True],
        [False, False, False,  True, False,  True],
        [ True,  True, False,  True, False, False],
        [ True,  True, False,  True,  True,  True],
        [False, False,  True,  True, False,  True]])
tensor([[1.0000, 1.0000, 1.0000, 0.3837, 0.0440, 1.0000],
        [0.0643, 0.2934, 0.4264, 1.0000, 0.4667, 1.0000],
        [1.0000, 1.0000, 0.0858, 1.0000, 0.3860, 0.4047],
        [1.0000, 1.0000, 0.4721, 1.0000, 1.0000, 1.0000],
        [0.4886, 0.2380, 1.0000, 1.0000, 0.4049, 1.0000]])

1.3布尔矩阵的强转

(decoder_out > 0.05).type(torch.int16)

通过强转也可以将布尔矩阵转为数值矩阵,然后在做数值运算。

a = torch.rand((5, 6), dtype=torch.float32)
print(a)
print((a > 0.5).type(torch.int16))

--------------------------------------------------------------------------------------
tensor([[0.6869, 0.0706, 0.1450, 0.2567, 0.9260, 0.2848],
        [0.8413, 0.1677, 0.2624, 0.7488, 0.5229, 0.0857],
        [0.9411, 0.9034, 0.0416, 0.4502, 0.9404, 0.3534],
        [0.9312, 0.5303, 0.3516, 0.2819, 0.3869, 0.4441],
        [0.0093, 0.5280, 0.7881, 0.1460, 0.9870, 0.1433]])
tensor([[1, 0, 0, 0, 1, 0],
        [1, 0, 0, 1, 1, 0],
        [1, 1, 0, 0, 1, 0],
        [1, 1, 0, 0, 0, 0],
        [0, 1, 1, 0, 1, 0]], dtype=torch.int16)

2.通过块索引替换

类似于python的列表索引切片,tensor也可以这样做。

a = torch.rand((5, 6), dtype=torch.float32)
print(a)
a[0:2, 0:2] = 1
print(a)

-------------------------------------------------------------------------------------------
tensor([[0.2457, 0.9803, 0.6518, 0.7468, 0.8744, 0.9395],
        [0.5727, 0.4081, 0.2449, 0.3435, 0.5808, 0.9078],
        [0.6032, 0.8980, 0.8591, 0.9664, 0.1635, 0.8161],
        [0.0347, 0.3219, 0.0546, 0.2887, 0.6355, 0.8978],
        [0.0848, 0.8553, 0.6150, 0.6221, 0.9916, 0.0254]])
tensor([[1.0000, 1.0000, 0.6518, 0.7468, 0.8744, 0.9395],
        [1.0000, 1.0000, 0.2449, 0.3435, 0.5808, 0.9078],
        [0.6032, 0.8980, 0.8591, 0.9664, 0.1635, 0.8161],
        [0.0347, 0.3219, 0.0546, 0.2887, 0.6355, 0.8978],
        [0.0848, 0.8553, 0.6150, 0.6221, 0.9916, 0.0254]])

3.clip操作

pytorch官方的 api 输入一个tensor input,设定最小值min 最大值max,超过最大值的数值会被设置为最大值,小于最小值的数值被设置为最小值。

torch.clip(input, min=None, max=None, *, out=None) → Tensor

相同的还有torch.clamp API,作用与clip相同。

torch.clamp(inputmin=Nonemax=None*out=None) → Tensor 

a = torch.randn(4)
print(a)
torch.clamp(a, min=-0.5, max=0.5)
print(a)

-------------------------------------------------------------------------------------------

tensor([-1.7120,  0.1734, -0.0478, -0.0922])

tensor([-0.5000,  0.1734, -0.0478, -0.0922])

猜你喜欢

转载自blog.csdn.net/qq_55796594/article/details/125061049