pytorch使用总结

转载链接 https://blog.csdn.net/tfygg/article/details/70227388

torch.Tensor - 一个多维数组

autograd.Variable - 改变Tensor并且记录下来操作的历史记录。和Tensor拥有相同的API,以及backward()的一些API。同时包含着和张量相关的梯度。

nn.Module - 神经网络模块。便捷的数据封装,能够将运算移往GPU,还包括一些输入输出的东西。

nn.Parameter - 一种变量,当将任何值赋予Module时自动注册为一个参数。

autograd.Function - 实现了使用自动求导方法的前馈和后馈的定义。每个Variable的操作都会生成至少一个独立的Function节点,与生成了Variable的函数相连之后记录下操作历史。

1、Tensors与numpy之间转换

  1. # 此处演示tensor和numpy数据结构的相互转换

  2. a = torch.ones(5)

  3. b = a.numpy()

  4. # 此处演示当修改numpy数组之后,与之相关联的tensor也会相应的被修改

  5. a.add_(1)

  6. print(a)

  7. print(b)

  8. # 将numpy的Array转换为torch的Tensor

  9. import numpy as np

  10. a = np.ones(5)

  11. b = torch.from_numpy(a)

  12. np.add(a, 1, out=a)

  13. print(a)

  14. print(b)

2、autograd.Variable 这是这个包中最核心的类。 它包装了一个Tensor,并且几乎支持所有的定义在其上的操作。一旦完成了你的运算,你可以调用 .backward()来自动计算出所有的梯度。

可以通过属性 .data 来访问原始的tensor,而关于这一Variable的梯度则集中于 .grad 属性中。

3、torch.nn 只接受小批量的数据

        整个torch.nn包只接受那种小批量样本的数据,而非单个样本。 例如,nn.Conv2d能够结构一个四维的TensornSamples x nChannels x Height x Width。如果你拿的是单个样本,使用input.unsqueeze(0)来加一个假维度就可以了。

4、数据读入

通常来讲,当你处理图像,声音,文本,视频时需要使用python中其他独立的包来将他们转换为numpy中的数组,之后再转换为torch.*Tensor。

(1)图像的话,可以用Pillow, OpenCV。

(2)声音处理可以用scipy和librosa。

(3)文本的处理使用原生Python或者Cython以及NLTK和SpaCy都可以。

特别的对于图像,我们有torchvision这个包可用,其中包含了一些现成的数据集如:Imagenet, CIFAR10, MNIST等等。同时还有一些转换图像用的工具。 这非常的方便并且避免了写样板代码。

5、模型的保存与加载

        torch.save()实现对网络结构和模型参数的保存。有两种保存方式:一是保存年整个神经网络的的结构信息和模型参数信息,save的对象是网络net;二是只保存神经网络的训练模型参数,save的对象是net.state_dict()。

  1. torch.save(net1, '7-net.pth') # 保存整个神经网络的结构和模型参数

  2. torch.save(net1.state_dict(), '7-net_params.pth') # 只保存神经网络的模型参数

        对应上面两种保存方式,重载方式也有两种。对应第一种完整网络结构信息,重载的时候通过torch.load(‘.pth’)直接初始化新的神经网络对象即可。对应第二种只保存模型参数信息,需要首先导入对应的网络,通过net.load_state_dict(torch.load('.pth'))完成模型参数的重载。在网络比较大的时候,第一种方法会花费较多的时间。

  1. # 保存和加载整个模型

  2. torch.save(model_object, 'model.pkl')

  3. model = torch.load('model.pkl')

  4. # 仅保存和加载模型参数(推荐使用)

  5. torch.save(model_object.state_dict(), 'params.pkl')

  6. model_object.load_state_dict(torch.load('params.pkl'))

猜你喜欢

转载自blog.csdn.net/zhuimengshaonian66/article/details/81161752