目标检测——在训练PyTorch模型遇到显存不足的情况怎么办?

1 前言

在目标检测中,可能会遇到显存不足的情况,我们在这里记录一下解决方案;

2 如何减少PyTorch模型需要的显存

2.1 减小batch_size的数量

最小的数量可以设置为2;

2.2 使用checkpoint对模型进行优化

我觉得checkpoint优化PyTorch模型的原理,主要在于不保存中间过程中的激活值,

checkpoint的文档中也是这样说的:

Checkpointing works by trading compute for memory. Rather than storing all intermediate activations of the entire computation graph for computing backward, the checkpointed part does not save intermediate activations, and instead recomputes them in backward pass. It can be applied on any part of a model.

用自洽性代码测试一下,为什么测试不成功?

网友ONE_SIX_MIX给出的简单例子:《torch.utils.checkpoint 简介 和 简易使用》

PyTorch官方在DenseNet中的写作实例,参见:vision/torchvision/models/densenet.py

2.3 可以减小分辨率

减小分辨率是节省显存的一种可行的做法,(但是南溪是不推荐这种做法的,如果实在显存不够,可以减小batch_size的大小,因为降低分辨率从本质上看,是对图像进行下采样,这样必然会丢失图像的原始信息,这是没有必要的);

注意:在下采样输入图像进行训练时,对于标注数据不能进行下采样,否则会出现意想不到的错误;

这是可以理解的,因为作为loss函数,他只是一个改卷老师,他并不关心模型内部进行了怎样的处理,

他只关心:对于评分标准而言,你的回答是否是正确的;

所以我们在进行训练的时候,是不能对标注信息进行任何修改的,因为我们怎么能修改评分标准呢;

我们应该做的就是把预测数据的还原操作作为模型的一部分,放入模型的前向运算中,而不是去修改标注数据;

发布了323 篇原创文章 · 获赞 97 · 访问量 34万+

猜你喜欢

转载自blog.csdn.net/songyuc/article/details/104754557