pytorch批量替换张量中的元素值(torch.where)

pytorch批量替换张量中的元素值(torch.where)

代码

import torch

x1 = [[[[1, 2], [3, 4]]]]
x1 = torch.tensor(x1)

# 查询x1中的是否存在等于1的元素,如果存在,则使用382替换1
# 在实际应用中,382可以设置为任何值。
b = torch.where(x1 == 1, 382, x1)
print(b)


运行结果

在这里插入图片描述

猜你喜欢

转载自blog.csdn.net/qq_34848334/article/details/129732106