之前在使用Pytorch训练模型的时候一直用的是单卡或者是DataParallel,并没有使用官方推荐的DistributedDataParallel,因为配置起来略显麻烦(懒癌晚期。。。)。最近在修改网上的一个开源模型,该代码使用到了DistributedDataParallel进行分布式训练,在修改的过程中学习了一些DistributedDataParallel的知识,特地记录一下。
DistributedDataParallel的好处
这个不用多讲,既然是官网推荐的方式,肯定是最好的。不过就是配置起来比较麻烦,而且对新手来说较单卡或者DataParallel有不少隐藏的坑。
DistributedDataParallel的坑
1、DDP包装过后的模型参数名称
首先第一个坑就是模型参数名称。一般来说,我们定义一个模型的时候都会给每个子模块定义一个名称。比如:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 1) #定义一个名为conv1的卷积层
def forward(self, x):
out = self.conv1(x)
return out
在Net这个模型中,我们定义了一个名为self.conv1的卷积层。
我们在训练的时候一般都会保存一些中间的训练权重文件,因为有时候会出现意外的情况导致训练中止,这样我们就可以加载某一时间点的权重继续训练,避免冲头开始训练浪费大量时间。
一般来说,加载保存的中间权重继续训练大致步骤如下:
net = Net() #实例化一个Net模型,这个模型是一个单卡的模型
if args.checkpoint: #args.checkpoint是权重路径
print("loading pytorch ckpt...", args.checkpoint)
cpu_device = torch.device("cpu")
ckpt = torch.load(args.checkpoint, map_location=cpu_device) #将保存的权重映射到cpu上
net.load_state_dict(ckpt,strict=False) #将训练的权重字典加载到模型net上
对于单卡或者DataParallel包装的模型来说,这样的加载权重方式没有太大问题。
但是,对于DDP包装过的模型来说,这样加载出来的模型在继续训练的时候就会出现loss很大的情况,跟从头开始训练几乎没有分别!!!
[Epoch 150/300][Iter 0/1034][lr 0.000052][Loss: anchor 5.96, iou 5.31, l1 32.46, conf 22200.45, cls 55.65, imgsize 608, time: 5.66]
#这是我刚开始从Epoch150继续训练显示的loss,跟从Epoch0开始训练的loss量级一样。
这种情况在我有限的调参经验来看,大概率就是模型权重没有正确加载,经过一番面向搜索引擎学习之后,得知这是DistributedDataParallel的一个坑,就是DistributedDataParallel包装模型之后会给模型参数名字前面自动加上一个module.!!!
在我将net.load_state_dict(ckpt, strict=False) 中的strict改为True之后模型就不能正确加载了,提示一下错误:
RuntimeError: Error(s) in loading state_dict for YOLOv3:
Missing key(s) in state_dict: "conv1.weight", .....
大致意思就是丢失了模型的conv1.weight等key,为啥呢?在我将保存的权重参数ckpt的key打印出来就知道了,ckpt的参数key已经变成了module.conv1.weight,在strict=True模式下模型和权重的key不同自然就不能正确加载了。
而在第一种strict=False情况下,二者key不同pytorch就不加载这个key的权重而是默认初始化了模型的这个key的权重,这就是为什么在第一种情况下加载150 epoch权重继续训练的loss和从头开始训练一样大,因为加载的结果几乎就是默认初始化这个模型,相当于白加载了。
所以,要怎么才能正确地将权重加载到模型上从而继续训练呢?
最intuitive的想法自然就是将DistributedDataParallel加到模型参数名称前面的**module.**这7个字符去掉就行啦!
我的做法就是:
net = Net() #实例化一个Net模型,这个模型是一个单卡的模型
if args.checkpoint: #args.checkpoint是权重路径
print("loading pytorch ckpt...", args.checkpoint)
cpu_device = torch.device("cpu")
ckpt = torch.load(args.checkpoint, map_location=cpu_device) #将保存的权重映射到cpu上
######################## here ############################
from collections import OrderedDict
sd = OrderedDict()
for item in ckpt.items():
sd_key = item[0][7:] #remove 'module.'
sd[sd_key] = item[1]
model.load_state_dict(sd,strict=True)
##########################################################
具体做法是:新建一个OrderedDict,将ckpt的每个带有module. 前缀的key去掉这个前缀得到新的sd_key,并将sd_key作为OrderedDict的key,把带有前缀的key对应的值value赋给新的sd_key,这样就得到了一个和ckpt一一对应的新的OrderedDict,这个Dict的权重就是之前保存权重,只是key的名称没有了module. 前缀。这样把这个新的Dict加载到实例化出来的模型继续训练就没有问题啦!!
这是我修改加载方式后从150Epoch继续训练的loss:
[Epoch 150/300][Iter 0/1034][lr 0.000052][Loss: anchor 2.88, iou 3.02, l1 11.04, conf 5.71, cls 19.27, imgsize 608, time: 5.53]
可看到,loss与之前相比小了很多,基本与我训练日志上的loss一致。
2、DDP训练时出现 RuntimeError: Address already in use
出现这个问题原因是已经有人在机器上已默认端口使用了分布式训练,这里要做的是在训练时指定一个没有使用的端口就可以,例如
python -m torch.distributed.launch --nproc_per_node=4 --master_port=23455 train.py