pytorch中mask_select()的用法

pytorch中mask_select()的用法

import torch

a =torch.Tensor([1,2,4,4,5])
print(torch.masked_select(a, a<4))

1.a<4取出的是布尔值索引(掩码)[1,1,0,0,0,]
2.torch.masked_select(a, a<4):根据a<4的非0掩码从a中取值
print(torch.masked_select(a, a<4)):
在这里插入图片描述

发布了18 篇原创文章 · 获赞 2 · 访问量 340

猜你喜欢

转载自blog.csdn.net/weixin_44928646/article/details/104629013