pytorch的tensor与numpy数组的对比

创建Tensor下表给了一些常用的作参考。

函数 功能
Tensor(*sizes) 基础构造函数
tensor(data,) 类似np.array的构造函数
ones(*sizes) 全1Tensor
zeros(*sizes) 全0Tensor
eye(*sizes) 对角线为1,其他为0
arange(s,e,step) 从s到e,步长为step
linspace(s,e,steps) 从s到e,均匀切分成steps份
rand/randn(*sizes) 均匀/标准分布
normal(mean,std)/uniform(from,to) 正态分布/均匀分布
randperm(m) 随机排列

创建Tensor和NumPy的数组

import numpy as np
import torch

# 类似于numpy
x = torch.arange(1,9,2)
print(x)
y = np.arange(1,9,2)
print(y)

# 创建5*3的未初始化tensor
x = torch.empty(5,3)
print(x)
y = np.empty([5,3])
print(y)

# 创建5*3的随机初始化tensor
x = torch.rand(5,3)
print(x)
y = np.random.rand(5,3)
print(y)

# 创建5*3的long型全0的tensor
x = torch.zeros(5,3,dtype=torch.long)
print(x)
y = np.zeros([5,3],dtype=np.long)
print(y)

# 创建4*4全为1的tensor
x = torch.ones(4,4)
print(x)
y = np.ones([4,4])
print(y)

# 创建4*4的单位tensor
x = torch.eye(4,4)
print(x)
y = np.eye(4,4)
print(y)

# 直接创建tensor
x = torch.tensor([[1,2,3],[4,5,6]])
print(x)
y = np.array([[1,2,3],[4,5,6]])
print(y)
x = torch.tensor(np.arange(12).reshape(3,4))
print(x)

# tensor一些属性
x = torch.tensor([[1,2],[3,4]])
print(x)
y = np.array([[1,2],[3,4]])
# 查看数据类型
print(x.dtype)
print(y.dtype)
# 查看形状
print(x.shape) # print(x.size())
print(x.size)
print(y.shape)
print(y.size)
tensor([1, 3, 5, 7])
[1 3 5 7]
tensor([[4.1327e-39, 8.9082e-39, 9.8265e-39],
        [9.4592e-39, 1.0561e-38, 1.0653e-38],
        [1.0469e-38, 9.5510e-39, 1.0102e-38],
        [8.4490e-39, 8.9082e-39, 8.4490e-39],
        [1.0194e-38, 1.0745e-38, 9.2755e-39]])
[[0.99537969 0.63659559 0.84111795]
 [0.12392895 0.84133229 0.44756633]
 [0.73434498 0.92331792 0.15648406]
 [0.7112982  0.59707987 0.14628696]
 [0.46465904 0.62735071 0.1993323 ]]
tensor([[0.8919, 0.5340, 0.8707],
        [0.7676, 0.7845, 0.3222],
        [0.3889, 0.1721, 0.4904],
        [0.7376, 0.8286, 0.8451],
        [0.2976, 0.2889, 0.5187]])
[[0.73119602 0.53034057 0.4751409 ]
 [0.13117025 0.30697661 0.21826581]
 [0.11899772 0.25654165 0.62288326]
 [0.36888543 0.08219027 0.65804774]
 [0.45289577 0.51710629 0.36201635]]
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])
[[0 0 0]
 [0 0 0]
 [0 0 0]
 [0 0 0]
 [0 0 0]]
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
[[1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]
 [1. 1. 1. 1.]]
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]])
[[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]]
tensor([[1, 2, 3],
        [4, 5, 6]])
[[1 2 3]
 [4 5 6]]
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]], dtype=torch.int32)
tensor([[1, 2],
        [3, 4]])
torch.int64
int32
torch.Size([2, 2])
<built-in method size of Tensor object at 0x0000015D601B17C8>
(2, 2)
4

Tensor和NumPy相互转换

我们很容易用numpy()from_numpy()Tensor和NumPy中的数组相互转换。但是需要注意的一点是:
这两个函数所产生的的Tensor和NumPy中的数组共享相同的内存(所以他们之间的转换很快),改变其中一个时另一个也会改变!!!

还有一个常用的将NumPy中的array转换成Tensor的方法就是torch.tensor(), 需要注意的是,此方法总是会进行数据拷贝(就会消耗更多的时间和空间),所以返回的Tensor和原来的数据不再共享内存。

Tensor转NumPy

使用numpy()Tensor转换成NumPy数组:

a = torch.ones(5)
b = a.numpy()
print(a, b)

a += 1
print(a, b)
b += 1
print(a, b)

输出:

tensor([1., 1., 1., 1., 1.]) [1. 1. 1. 1. 1.]
tensor([2., 2., 2., 2., 2.]) [2. 2. 2. 2. 2.]
tensor([3., 3., 3., 3., 3.]) [3. 3. 3. 3. 3.]

NumPy数组转Tensor

使用from_numpy()将NumPy数组转换成Tensor:

import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)
print(a, b)

a += 1
print(a, b)
b += 1
print(a, b)

输出:

[1. 1. 1. 1. 1.] tensor([1., 1., 1., 1., 1.], dtype=torch.float64)
[2. 2. 2. 2. 2.] tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
[3. 3. 3. 3. 3.] tensor([3., 3., 3., 3., 3.], dtype=torch.float64)

所有在CPU上的Tensor(除了CharTensor)都支持与NumPy数组相互转换。

此外上面提到还有一个常用的方法就是直接用torch.tensor()将NumPy数组转换成Tensor,需要注意的是该方法总是会进行数据拷贝,返回的Tensor和原来的数据不再共享内存。

c = torch.tensor(a)
a += 1
print(a, c)

输出

[4. 4. 4. 4. 4.] tensor([3., 3., 3., 3., 3.], dtype=torch.float64)

参考资料:
https://github.com/ShusenTang/Dive-into-DL-PyTorch/blob/master/docs/chapter02_prerequisite/2.2_tensor.md

发布了50 篇原创文章 · 获赞 19 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/qq_28368377/article/details/103096377