Pytorch는로드하기 전에 학습 모델의 일부 매개 변수를로드하고 일부 고정 된 매개 변수 (실제 프로젝트 코드에서 측정 됨)를로드합니다.

내 요구 사항은 다양한 모델을 계속 시도하기 때문에 모델 블록이 항상 변경된다는 것입니다. 매번 훈련을 다시 시작하려면 많은 시간이 걸립니다.

이전에 실행 한 모델은 ResNet-> 세 개의 ResNet 매개 변수가 공유되었습니다.

                              ResNet-> 중급 모듈-> 결과

                              ResNet->

 

이제 매개 변수 공유없이 ResNet 1-> 3 개의 ResNet으로 변경하여 재 학습시키고 싶습니다. 이전 모델의 중간 모듈의 매개 변수를 가져오고 싶습니다.

                              ResNet 2-> 중급 모듈-> 결과    

                              ResNet 3->

중간 모듈의 매개 변수를 고정하면 훈련 속도가 빨라집니다.

두 위대한 신에 대한 두 개의 블로그 게시물을 참조하십시오 : 일부 매개 변수로드 https://blog.csdn.net/weixin_41519463/article/details/101604662 , 일부 매개 변수 고정 https://blog.csdn.net/jdzwanghao/article/ 세부 정보 / 83239111 .

구체적인 코드는 다음과 같습니다.

net = MY_Net( )
######导入部分参数
	model_dict = net.state_dict()
	for k, v in model_dict.items():
		print(k)

	pretrained_dict = torch.load(model_file1)#model_file1是之前模型的模型保存路径,这里只是加载参数而已
	for k, v in pretrained_dict.items():
		print(k)
	pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

	model_dict.update(pretrained_dict)  # 用预训练模型参数更新new_model中的部分参数

	net.load_state_dict(model_dict)  # 将更新后的model_dict加载进new model中


##### 冻结部分参数
	for param in net.parameters():
		param.requires_grad = False#设置所有参数不可导,下面选择设置可导的参数
	for param in net.ResNet1.parameters():
		param.requires_grad = True
	for param in net.ResNet2.parameters():
		param.requires_grad = True
	for param in net.ResNet3.parameters():
		param.requires_grad = True

optimizer = optim.SGD(filter(lambda p: p.requires_grad, net.parameters()), lr = 0.0001, momentum=0.90,weight_decay=0.0005)#关键是优化器中通filter来过滤掉那些不可导的参数

 

추천

출처blog.csdn.net/qq_36401512/article/details/105076090