torch.Tensor 与 numpy.ndarray 之间的相互转换
torch.Tensor 转换成 numpy.ndarray（使用 Tensor.numpy()：仅适用于位于 CPU 上且不需要梯度的张量；返回的数组与原张量共享内存，修改其一另一个也会改变）
import torch

# Demo: convert a CPU torch.Tensor into a numpy.ndarray.
a = torch.ones(5)  # float32 tensor of five ones on the CPU
print(a, type(a), sep="\n")  # second line: <class 'torch.Tensor'>

# Tensor.numpy() does not copy: `b` views the same storage as `a`,
# so in-place changes to one are visible through the other.
b = a.numpy()
print(b, type(b), sep="\n")  # second line: <class 'numpy.ndarray'>
numpy.ndarray 转换成 torch.Tensor（使用 torch.from_numpy()：返回的张量与原数组共享内存，且保持原数组的 dtype，例如 np.ones 得到的 float64 会变成 torch.float64）
import numpy as np
import torch

# Demo: convert a numpy.ndarray into a torch.Tensor.
a = np.ones(5)  # float64 array of five ones (NumPy's default dtype)

# torch.from_numpy() wraps the array's buffer without copying: the tensor
# keeps the array's dtype (torch.float64 here) and shares its memory.
b = torch.from_numpy(a)

print(a, type(a), sep="\n")  # second line: <class 'numpy.ndarray'>
print(b, type(b), sep="\n")  # second line: <class 'torch.Tensor'>