【PyTorch】Tensor.masked_select() 方法:筛选出掩码为 True 的元素

Tensor.masked_select() 方法详解

torch.Tensor.masked_select() 是 PyTorch 中用于从张量中 筛选出掩码为 True 的元素 的方法。它是对张量进行 条件筛选/提取 的常用工具,返回的是一个 1D 的新张量


1. 函数原型

Tensor.masked_select(mask) → Tensor
参数 说明
mask 与原张量形状相同或可广播的布尔类型张量,True 表示保留该位置的值
返回 所有满足掩码条件的元素组成的 1D 张量

2. 功能说明

masked_select() 会遍历输入张量,对 mask == True 的位置提取值,返回这些元素构成的 一维张量


3. 示例:基础用法

import torch

x = torch.tensor([[1, 2], [3, 4]])
mask = x > 2

selected = x.masked_select(mask)
print(selected)  # tensor([3, 4])

解释:

  • x > 2 得到布尔掩码 [[False, False], [True, True]]
  • 只保留 34,返回的张量是一维的。

4. 广播兼容的掩码

x = torch.arange(12).reshape(3, 4)
mask = torch.tensor([True, False, True]).unsqueeze(1)  # shape (3, 1)

# 广播成 (3, 4),选中第 0 和 2 行所有元素
selected = x.masked_select(mask)
print(selected)  # tensor([0, 1, 2, 3, 8, 9, 10, 11])

5. 应用场景

应用 示例
筛选满足条件的元素 x.masked_select(x > 0)
计算掩码条件下的均值 x.masked_select(mask).mean()
用于损失函数掩码 loss.masked_select(valid_mask).mean()
构建条件样本集合 从预测或输入中选出满足条件的部分

6. 与其他 mask 操作对比

方法 功能 输出形状
masked_fill() 条件替换为某个值 与输入相同
masked_scatter() 条件替换为另一个张量的值 与输入相同
masked_select() 条件筛选出值 1D 张量
where() 条件选择或获取位置索引 可广播,多功能

7. 注意事项

  • mask 必须是 torch.bool 类型;
  • 返回的张量是 一维的,无论原始张量是多少维;
  • 返回值是 新张量,原张量不变。

8. 总结

特性 说明
功能 从张量中按掩码提取元素
掩码类型 BoolTensor,可与输入广播
返回值 1D 张量(所有符合条件的值)
常见用途 条件过滤、loss 掩码、统计分析等

masked_select() 是深度学习中处理变长输入、掩码样本、注意力机制筛选等场景的常用函数