pytorch 将tensor 类型转为python中的常用数据类型

假设变量y是pytorch中的一个Tensor类型,如下所示:

y = torch.sum(m)
print(y)
print(type(y))
print(y.item())
print(type(y.item()))

则使用y.item()则可以将其转化为float类型,程序输出结果如下所示:

tensor(452.4124, device='cuda:0', grad_fn=<SumBackward0>)
<class 'torch.Tensor'>
452.4123840332031
<class 'float'>

不同版本的pytorch解决上述问题的方式可能还不一样,我这个是pytorch1.2

发布了36 篇原创文章 · 获赞 11 · 访问量 6540

猜你喜欢

转载自blog.csdn.net/t20134297/article/details/103850377