pytorch expand

基本用法

以下由chatgpt生成

PyTorch的expand方法是用于扩展张量维度的。它是Tensor的一个方法,可以用来复制张量的形状,把一个小张量复制成一个更大的张量。

最常见的使用场景是扩展某些维度,以适配其他张量的形状。例如,假设你有一个N维张量A,并且你想要在B维张量上对其进行操作,但B的维度大于A。在这种情况下,可以使用expand方法扩展A的维度,以适配B的维度。

具体地,该方法的用法如下:

expanded_tensor = tensor.expand(sizes)

其中,tensor是需要扩展的张量,sizes是一个整数列表,表示扩展后的形状。例如,如果tensor的形状为(3, 4),而sizes为(3, 4, 5),那么扩展后的张量的形状将为(3, 4, 5)。

简单示例:

import torch

a = torch.tensor([1, 2, 3])
b = a.expand(3, 3)
print(b)

# output:
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [1, 2, 3]])

在上面的例子中,我们通过使用expand方法,将1维张量a扩展为2维张量b,并通过指定sizes参数,指定了扩展后张量的形状为(3, 3)。

sizes含-1的情况

在参数列表sizes中如果含有-1,PyTorch会尽量匹配该维度大小,使其与原始维度相等。这样就能只拓展被制定了大小的维度。

简单示例:

import torch

a = torch.tensor([1, 2, 3, 4])
b = a.expand(2, -1)
print(b)

# output:
# tensor([[1, 2, 3, 4],
#         [1, 2, 3, 4]])

在上面的例子中,我们将张量a的形状扩展为(2, 4)。张量a沿着维度0拓展了两遍。

猜你喜欢

转载自blog.csdn.net/duoyasong5907/article/details/128991391