pytorch之view函数

改变形状

注意 view() 返回的新tensor与源tensor共享内存(其实是同⼀个tensor),也即更改其中的⼀个,另 外⼀个也会跟着改变。(顾名思义,view仅是改变了对这个张量的观察角度)

代码: 

x = torch.FloatTensor([[1, 2, 3], [2, 3, 4], [3, 4, 5], [4, 5, 6], [5, 6, 7]])
y = x.view(15)
z = y.view(-1, 5)

输出:

torch.Size([5, 3]) torch.Size([15]) torch.Size([3, 5])

代码:

x += 1
print("x=", x)
print("y=", y)

​​​​​​​输出:

x= tensor([[2., 3., 4.],
        [3., 4., 5.],
        [4., 5., 6.],
        [5., 6., 7.],
        [6., 7., 8.]])
y= tensor([2., 3., 4., 3., 4., 5., 4., 5., 6., 5., 6., 7., 6., 7., 8.])

所以如果我们想返回⼀个真正新的副本(即不共享内存)该怎么办呢?Pytorch还提供了⼀ 个 reshape() 可以改变形状,但是此函数并不能保证返回的是其拷贝,所以不推荐使用。推荐先 ⽤ clone 创造一个副本然后再使⽤ view 。

​​​​​​​代码:

x_new = x.clone().view(15)
x -= 1
print("x=", x)
print("x_new=", x_new)

输出:

x= tensor([[1., 2., 3.],
        [2., 3., 4.],
        [3., 4., 5.],
        [4., 5., 6.],
        [5., 6., 7.]])
x_new= tensor([2., 3., 4., 3., 4., 5., 4., 5., 6., 5., 6., 7., 6., 7., 8.])

__________________________________________________________________________________

今天就水到这里

猜你喜欢

转载自blog.csdn.net/qq_38890412/article/details/107621664