Tensor.masked_select()
方法详解
torch.Tensor.masked_select()
是 PyTorch 中用于从张量中 筛选出掩码为 True
的元素 的方法。它是对张量进行 条件筛选/提取 的常用工具,返回的是一个 1D 的新张量。
1. 函数原型
Tensor.masked_select(mask) → Tensor
参数 | 说明 |
---|---|
mask |
与原张量形状相同或可广播的布尔类型张量,True 表示保留该位置的值 |
返回 | 所有满足掩码条件的元素组成的 1D 张量 |
2. 功能说明
masked_select()
会遍历输入张量,对 mask == True
的位置提取值,返回这些元素构成的 一维张量。
3. 示例:基础用法
import torch
x = torch.tensor([[1, 2], [3, 4]])
mask = x > 2
selected = x.masked_select(mask)
print(selected) # tensor([3, 4])
解释:
x > 2
得到布尔掩码[[False, False], [True, True]]
- 只保留
3
和4
,返回的张量是一维的。
4. 广播兼容的掩码
x = torch.arange(12).reshape(3, 4)
mask = torch.tensor([True, False, True]).unsqueeze(1) # shape (3, 1)
# 广播成 (3, 4),选中第 0 和 2 行所有元素
selected = x.masked_select(mask)
print(selected) # tensor([0, 1, 2, 3, 8, 9, 10, 11])
5. 应用场景
应用 | 示例 |
---|---|
筛选满足条件的元素 | x.masked_select(x > 0) |
计算掩码条件下的均值 | x.masked_select(mask).mean() |
用于损失函数掩码 | loss.masked_select(valid_mask).mean() |
构建条件样本集合 | 从预测或输入中选出满足条件的部分 |
6. 与其他 mask 操作对比
方法 | 功能 | 输出形状 |
---|---|---|
masked_fill() |
条件替换为某个值 | 与输入相同 |
masked_scatter() |
条件替换为另一个张量的值 | 与输入相同 |
masked_select() |
条件筛选出值 | 1D 张量 |
where() |
条件选择或获取位置索引 | 可广播,多功能 |
7. 注意事项
mask
必须是torch.bool
类型;- 返回的张量是 一维的,无论原始张量是多少维;
- 返回值是 新张量,原张量不变。
8. 总结
特性 | 说明 |
---|---|
功能 | 从张量中按掩码提取元素 |
掩码类型 | BoolTensor ,可与输入广播 |
返回值 | 1D 张量(所有符合条件的值) |
常见用途 | 条件过滤、loss 掩码、统计分析等 |
masked_select()
是深度学习中处理变长输入、掩码样本、注意力机制筛选等场景的常用函数。