深度学习框架_PyTorch_repeat()函数

PyTorch中的repeat()函数可以对张量进行复制。

当参数只有两个时,第一个参数表示的是复制后的列数,第二个参数表示复制后的行数。

当参数有三个时,第一个参数表示的是复制后的通道数,第二个参数表示的是复制后的列数,第三个参数表示复制后的行数。

接下来我们举一个例子来直观理解一下:

>>> x = torch.tensor([6,7,8])
>>> x.repeat(4,2)
tensor([[6, 7, 8, 6, 7, 8],
        [6, 7, 8, 6, 7, 8],
        [6, 7, 8, 6, 7, 8],
        [6, 7, 8, 6, 7, 8]])
>>> x.repeat(4,2,1)
tensor([[[6, 7, 8],
         [6, 7, 8]],

        [[6, 7, 8],
         [6, 7, 8]],

        [[6, 7, 8],
         [6, 7, 8]],

        [[6, 7, 8],
         [6, 7, 8]]])
>>> x.repeat(4,2,1).size()
torch.Size([4, 2, 3])

发布了95 篇原创文章 · 获赞 37 · 访问量 3415

猜你喜欢

转载自blog.csdn.net/Rocky6688/article/details/103916869
今日推荐