【torch】如何把把几个 tensor 连接起来?(含源代码)

一、cat

在 PyTorch 中,要向一个 tensor 中添加元素,你通常需要创建一个新的 tensor,然后将元素添加到新的 tensor 中。PyTorch tensors 是不可变的,所以不能像列表一样直接追加元素。以下是如何实现向 tensor 中添加元素的一种方法:

import torch

# 初始的 tensor
x = torch.tensor([0.6580, -1.0969, -0.4614])

# 要添加的元素
new_element = torch.tensor([1.0000, 2.0000, 3.0000])

# 创建一个新的 tensor,将原始 tensor 和新元素拼接在一起
result = torch.cat((x, new_element), dim=0)

print(result)

我们的输出结果是:

tensor([ 0.6580, -1.0969, -0.4614,  1.0000,  2.0000,  3.0000])

二、stack

我们上面的结果是:tensor([ 0.6580, -1.0969, -0.4614, 1.0000, 2.0000, 3.0000]),但是我如果想把结果转为下面的格式呢?

tensor([[ 0.6580, -1.0969, -0.4614],
        [ 1.0000,  2.0000,  3.0000]])

我该如何实现?可以使用stack函数。

import torch

x = torch.tensor([0.6580, -1.0969, -0.4614])
y = torch.tensor([1.0000, 2.0000, 3.0000])

# Concatenate x and y as rows (along dimension 0) to form a 2x3 tensor
result = torch.stack((x, y))

print(result)

我们的输出结果为:

tensor([[ 0.6580, -1.0969, -0.4614],
        [ 1.0000,  2.0000,  3.0000]])

但是如果我们继续添加,tensor的维度不同的时候,就会导致添加失败:

result = torch.stack((result, x))
result

代码就会报错:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
d:\CodeProject\37.帮助大家看代码\demo.ipynb 单元格 18 in 1
----> 1 result = torch.stack((result, x))
      2 result

RuntimeError: stack expects each tensor to be equal size, but got [2, 3] at entry 0 and [3] at entry 1

提示我们tensor的维度尺寸必须相同。

功能是不是有所限制?

三、cat如何进行逐行的合并?

如果你想逐个追加 tensor,可以使用 PyTorch 的 torch.cat 函数,在每次追加时创建一个新的 tensor。下面是一个示例,演示如何逐个追加 tensor:

import torch

# 创建一个空的 tensor
result = torch.tensor([])

# 逐个追加 tensor
x = torch.tensor([0.6580, -1.0969, -0.4614])
result = torch.cat((result, x.view(1, -1)), dim=0)

y = torch.tensor([1.0000, 2.0000, 3.0000])
result = torch.cat((result, y.view(1, -1)), dim=0)

# 打印结果
print(result)

在这个示例中,我们首先创建一个空的 tensor result,然后逐个追加 x 和 y。注意使用 view(1, -1) 来将每个 tensor 变形为一个行向量,以确保它们可以正确地连接在一起。最终,result 中包含了逐个追加的 tensor。

扫描二维码关注公众号,回复: 16475394 查看本文章

你可以继续使用类似的方法,逐个追加更多的 tensor。

上面的输出结果为:

tensor([[ 0.6580, -1.0969, -0.4614],
        [ 1.0000,  2.0000,  3.0000]])

我们测试一下是否可以逐个追加更多的tensor:

result = torch.cat((result, x.view(1, -1)), 0)
result

我们测试了四次,结果如下所示:

tensor([[ 0.6580, -1.0969, -0.4614],
        [ 1.0000,  2.0000,  3.0000],
        [ 0.6580, -1.0969, -0.4614],
        [ 0.6580, -1.0969, -0.4614],
        [ 0.6580, -1.0969, -0.4614],
        [ 0.6580, -1.0969, -0.4614]])

猜你喜欢

转载自blog.csdn.net/wzk4869/article/details/132710522
今日推荐