【Pytorch】张量索引None的含义

在 PyTorch 中,使用 None 作为张量索引的目的是为张量增加一个新的维度。具体来说,None 索引会创建一个大小为1的新维度,而不会改变原始张量中的任何数据。

假设 hidden_states 是一个形状为 (A, B) 的二维张量。当你使用 hidden_states[None, :] 进行索引时,你将得到一个形状为 (1, A, B) 的三维张量,其中原始数据保持不变,并在最前面增加了一个新维度。

以下是一个具体示例:

import torch

# 创建一个形状为 (3, 4) 的二维张量
hidden_states = torch.randn(3, 4)
print("原始张量:")
print(hidden_states)
print("原始张量形状:", hidden_states.shape)

# 使用 None 索引增加一个维度
new_hidden_states = hidden_states[None, :]
print("增加维度后的张量:")
print(new_hidden_states)
print("增加维度后的张量形状:", new_hidden_states.shape)

输出:

原始张量:
tensor([[ 0.8860,  0.2456,  0.2150, -0.4519],
        [ 0.8445, -0.5865, -0.0738,  0.1211],
        [-0.4374,  0.3314, -0.0214,  0.4563]])
原始张量形状: torch.Size([3, 4])
增加维度后的张量:
tensor([[[ 0.8860,  0.2456,  0.2150, -0.4519],
         [ 0.8445, -0.5865, -0.0738,  0.1211],
         [-0.4374,  0.3314, -0.0214,  0.4563]]])
增加维度后的张量形状: torch.Size([1, 3, 4])

如你所见,原始张量的形状从 (3, 4) 变为 (1, 3, 4),在最前面增加了一个新维度。这在需要调整张量维度以匹配其他张量形状的情况下非常有用。例如,当你需要将张量输入到期望三维输入的神经网络层时。

猜你喜欢

转载自blog.csdn.net/qq_56199570/article/details/129867119