Pytorchは、事前にトレーニングされたモデルを読み込み、微調整し、既存のモデルに独自のレイヤーを追加し、レイヤーごとに異なるパラメーターの更新を設定します

**事前トレーニングの読み込み:** 2つのタイプに分けられます:
1。以前にトレーニングしたモデルを読み込みます。

pretrained_params = torch.load('Pretrained_Model')
model = New_Model(xxx)
model.load_state_dict(pretrained_params.state_dict(),strict=False)

2.モデル
pytorchにロードし、残余ネットワーク18を例として取り上げます。

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
pretrained_dict = resnet18.state_dict()
model_dict = new_model.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict =  {
    
    k: v for k, v in pretrained_dict.items() if k in model_dict  and v.shape ==model_dict[k].shape}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
new_model.load_state_dict(model_dict)

微調整:

for name, value in model.named_parameters():
    if name 你需要固定的层:
        value.requires_grad = False
        
params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = torch.optim.Adam(params, lr=1e-5)

渡す必要のある固定レイヤーの名前がわからない場合:

print(model.state_dict().keys())

既存のモデルに独自のレイヤーを追加します
。1
最初に独自のモデルを作成します。2。既存のモデルパラメーターをモデルにロードします。ロード
されたResNet18を例として取り上げます。

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
pretrained_dict = resnet18.state_dict()
model_dict = new_model.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict =  {
    
    k: v for k, v in pretrained_dict.items() if k in model_dict  and v.shape ==model_dict[k].shape}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
new_model.load_state_dict(model_dict)

レイヤーごとに異なるパラメーター更新を設定します。
エンコード層と他のレイヤーの学習率が1e-5で、デコーダー層の学習率が1e-3、1であるとします。最初にレイヤーをデコードします
。2。フィルターでフィルターします他のレイヤーのパラメーターを取得し、それらを配列としてオプティマイザーに渡します。

ignored_params = list(map(id, model.decoder.parameters()))
base_params = filter(lambda p: id(p) not in ignored_params, model.parameters())
optimizer = torch.optim.Adam([{
    
    'params':base_params,'lr':1e-5},                              {
    
    'params':model.decoder.parameters()}],lr=1e-3, momentum=0.8)                         

おすすめ

転載: blog.csdn.net/qq_43360777/article/details/106305469