在科研过程中总结的一些琐碎的pytorch相关知识点。
1. 数据加载
-
锁页内存(
pin_memory
)是决定数据放在锁业内存还是硬盘的虚拟内存中,默认值为 False。如果设置为True,则表示数据放在锁业内存中。注意:显卡中的内存全部是锁页内存,所以放在锁页内存中可以加快读取速度。当计算机内存充足时,可将该值设置为 True。这一参数一般在data_loader()
函数中设置。 -
num_worker
取值最好是 2的幂次方-1, 如 0,1,3,7 等,因为会自动加 1. 默认值为 1. 这一参数一般在data_loader()
函数中设置。 -
GPU 利用率为低是因为显卡在等数据,解决办法(1)优化
data_loader()
函数;(2)增大batch size 等
2. 数据操作
- pytorch两个基本对象:
Tensor
(张量)和Variable
(变量)。 torch.Tensor
与torch.tensor
的区别:
torch.Tensor(data
):将数据转化torch.FloatTensor类型。
torch.tensor(data)
:根据数据类型或者dtype参数值将数据转化为torch.FloatTensor、torch.LongTensor、torch.DoubleTensor等类型。torch.contiguous()
:类似于 C++ 中的深拷贝。详解见此篇博客。torch.stack()
作用:用于连接大小相同的张量,并扩展维度,类比torch.cat()
. 注意:在哪个维度上操作,就将 dim 设置为哪个维度。 详解见此篇博客。- 使用
torch.zeros()
创建的张量默认在 CPU 上,如要在 GPU 上使用记得进行数据转移。 - 解决 torch 对象打印时有省略号的问题:
torch.set_printoptions(threshold=np.inf)
,该命令多用于打印完整日志。 - numpy 类型数据只能在 CPU 上运行。注意数据在torch类型与numpy类型间相互转换时数据的存放位置(如:不能将GPU上的张量数据直接转化为numpy类型数据)。
3. 模型操作
3.1 模式切换
model.eval()
与model.train()
区别在于是否启用 归一化层 + dropout,前者不启用,后者启用。
3.2 梯度更新
-
Module
中的层在定义时,相关Variable
的requires_grad
参数默认是True。而用户手动定义Variable
时,参数requires_grad
默认值是False,volatile
值也默认为False。volatile
的优先级比requires_grad
高,volatile
属性为True的节点不会求导(所以可以在测试阶段设置为 True)。 如果要修改可使用variable_name.require_grad_(True)
实现。 -
反向传播中梯度回传与更新的实现三步走: (1)
optimizer.zero_grad()
(梯度清零)(2)loss.backward()
(梯度回传)(3)optimizer.step()
(梯度更新) -
model.zero_grad ()
和optimizer.zero_grad ()
使用区别:当optimizer = optim.Optimizer (net.parameters ())
,即网络中参数均未冻结,全部需要更新时,二者等效,其中Optimizer可以是Adam、SGD等优化器;若网络中部分参数被冻结或多个网络共用同一个优化器,则二者不等价。详解见此篇博客。 -
with torch.no_grad()
作用:停止autograd
模块的工作,以起到加速和节省显存的作用。一般用在验证和测试阶段。注意:新版本Pytorch中,volatile
已被弃用,需替换为:with torch.no_grad()
.
3.3 模型保存与加载
torch.save(model, path)
:将训练好的模型 model 保存至 path 路径下。torch.load(model_path, map_location)
:将给定路径的预训练模型加载至指定设备上,详解见此篇博客。
参考资料
- Pytorch中contiguous()函数理解_.contiguous()_清晨的光明的博客-CSDN博客
- pytorch拼接函数:torch.stack()和torch.cat()–详解及例子_python torch拼接_紫芝的博客-CSDN博客
- pytorch之model.zero_grad() 与 optimizer.zero_grad()_models.zero_grad()_旺旺棒棒冰的博客-CSDN博客
- Pytorch:模型的保存与加载 torch.save()、torch.load()、torch.nn.Module.load_state_dict()_宁静致远*的博客-CSDN博客