O PyTorch salva alguns parâmetros do modelo e os carrega em um novo modelo

Eu li recentemente o artigo e vi que alguns programas são para treinar um modelo primeiro, depois levar parte da estrutura do modelo treinado para o novo modelo e, em seguida, usar os novos dados para treinar o novo modelo, mas os parâmetros dessa parte são Para mantê-lo, a princípio pensei que era muito semelhante ao aprendizado de transferência, porque não o havia estudado em detalhes, então não tinha certeza, então primeiro aprendi como realizar o plano da tese e o registrei aqui para referência futura. Quanto ao aprendizado por transferência, vamos estudar no futuro, deve ser usado!

estado_dict

Introdução ao state_dict

state_dicté um objeto de dicionário Python que pode ser usado para salvar parâmetros de modelo, hiperparâmetros e informações de estado do otimizador (torch.optim). Deve-se observar que apenas as camadas com parâmetros que podem ser aprendidos (como camadas convolucionais, camadas lineares, etc.) possuem um state_dict.

Dê uma castanha para ilustrar o uso de state_dict:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
# 定义模型
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
 
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
 
# 初始化模型
model = TheModelClass()
 
# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
# 打印模型的状态字典
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

saída:

Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
# 打印优化器的状态字典
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

saída:

Optimizer's state_dict:
state 	 {
    
    }
param_groups 	 [{
    
    'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]

Salvar e carregar state_dict

O state_dict do modelo pode torch.save()ser salvo por , ou seja, apenas os parâmetros do modelo aprendidos são salvos e os load_state_dict()parâmetros do modelo podem ser carregados e restaurados por . As extensões de salvamento de modelos mais comuns no PyTorch são ' .pt ' ou ' .pth '.

Salve o modelo no caminho atual com o nome test_state_dict.pth

PATH = './test_state_dict.pth'
torch.save(model.state_dict(), PATH)
 
model = TheModelClass()    # 首先通过代码获取模型结构
model.load_state_dict(torch.load(PATH))   # 然后加载模型的state_dict
model.eval()

Nota : A função load_state_dict() apenas aceita 字典对象e não pode passar diretamente no caminho do modelo, então você precisa usar o arch.load() para desserializar o state_dict salvo primeiro.

Salvar e carregar modelos completos

# 保存完整模型
torch.save(model, PATH)
 
# 加载完整模型
model = torch.load(PATH)
model.eval()

Embora o código desse método pareça mais conciso do que o método state_dict, ele é menos flexível. Porque a função arch.save() usa o módulo pickle do Python para serialização, mas o pickle não pode salvar o modelo em si, mas salva o caminho do arquivo contendo a classe, que será usado quando o modelo for carregado. Portanto, quando o modelo é refatorado em outros projetos, erros inesperados podem aparecer.

OrderedDict

Se imprimirmos state_dicto tipo de dados, obteremos a seguinte saída:

print(type(model.state_dict()))

saída:

<class 'collections.OrderedDict'>

O módulo de coleções implementa contêineres específicos de objeto para fornecer uma alternativa aos contêineres integrados padrão do Python dict , list , set e tuple .

class collections.OrderedDict([items])

OrderedDictdictUma instância de uma subclasse , um dicionário ordenado é como um dicionário regular, mas com algumas funcionalidades extras relacionadas a operações de classificação.

Vale a pena mencionar que depois de python3.7 , a classe interna dict ganhou a capacidade de lembrar a ordem de inserção, portanto, esse contêiner não é tão importante.

Algumas diferenças de dict:

  1. Ditos regulares são projetados para serem muito bons em operações de mapeamento. O rastreamento do pedido de inserção é secundário;
  2. OrderedDict foi projetado para ser bom em operações de reordenamento. Eficiência de espaço, velocidade de iteração e desempenho de operações de atualização são secundários;
  3. Algoritmicamente, OrderedDict pode lidar com operações de reordenação frequentes melhor do que dict. Isso o torna adequado para rastrear acessos recentes (por exemplo, em um cache LRU);

Salve alguns parâmetros do modelo e carregue-os em um novo modelo

Para o modelo acima, o dicionário de estado do modelo é:

Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

Se quisermos apenas salvar os parâmetros treinados de conv1 , podemos fazer isso:

save_state = {
    
    }
print("Model's state_dict:")
for param_tensor in model.state_dict():
    if 'conv1' in param_tensor:
        save_state.update({
    
    param_tensor:torch.ones((model.state_dict()[param_tensor].size()))})
        print(param_tensor, "\t", model.state_dict()[param_tensor].size())

Aqui, para conveniência das demonstrações subsequentes, nossa frase-chave é escrita assim:

save_state.update({
    
    param_tensor:torch.ones((model.state_dict()[param_tensor].size()))})

Mas, ao salvar, devemos escrever assim:

save_state.update({
    
    param_tensor:model.state_dict()[param_tensor]})

Em seguida, salve o dicionário save_state :

PATH = './test_state_dict.pth'
torch.save(save_state, PATH)

Em seguida, carregue o novo modelo e atribua os parâmetros salvos ao novo modelo:

model = TheModelClass()    # 首先通过代码获取模型结构
model.load_state_dict(torch.load(PATH), strict=False)   # 然后加载模型的state_dict

saída:

_IncompatibleKeys(missing_keys=['conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'], unexpected_keys=[])

Aqui está o modo de inicialização a quente, definindo o parâmetro estrito como Falso na função load_state_dict() para ignorar os parâmetros das chaves não correspondentes.

Vejamos novamente os parâmetros do novo modelo:

model.state_dict()['conv1.bias']

saída:

tensor([1., 1., 1., 1., 1., 1.])

Os parâmetros salvos entre descobertas foram carregados no novo modelo.

Veja os outros parâmetros no modelo:

model.state_dict()['conv2.bias']

saída:

tensor([ 0.0468,  0.0024, -0.0510,  0.0791,  0.0244, -0.0379, -0.0708,  0.0317,
        -0.0410, -0.0238,  0.0071,  0.0193, -0.0562, -0.0336,  0.0109, -0.0323])

Você pode ver que outros parâmetros são normais!

A diferença entre state_dict(), named_parameters(), model.parameter(), named_modules()

model.state_dict()

state_dict()É armazenar layer_name e layer_param como chaves na forma de dict . Contém os nomes e parâmetros de todas as camadas, os parâmetros do modelo armazenado tensor 的 require_grad 属性都是 False. O valor de saída não inclui require_grad. Você não pode usar model.state_dict() para obter parâmetros e definir o atributo require_grad ao corrigir uma determinada camada .

import torch
import torch.nn as nn
import torch.optim as optim
 
# 定义模型
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.bn = nn.BatchNorm2d(num_features=2)
        self.act = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(8, 4)
        self.softmax = nn.Softmax(dim=-1)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.pool(x)
        x = x.view(-1, 8)
        x = self.fc1(x)
        x = self.softmax(x)
        return x
 
# 初始化模型
model = TheModelClass()
 
# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
for param_tensor in model.state_dict():
    print(param_tensor, "\n", model.state_dict()[param_tensor])

saída:

conv1.weight 
 tensor([[[[ 0.2438, -0.0467,  0.0486],
          [-0.1932, -0.2083,  0.3239],
          [ 0.1712,  0.0379, -0.2381]]],


        [[[ 0.2853,  0.0961,  0.0809],
          [ 0.2526,  0.3138, -0.2243],
          [-0.1627, -0.2958, -0.1995]]]])    # 没有 require_grad
conv1.bias 
 tensor([-0.3287, -0.0686])
bn.weight 
 tensor([1., 1.])
bn.bias 
 tensor([0., 0.])
bn.running_mean 
 tensor([0., 0.])
bn.running_var 
 tensor([1., 1.])
bn.num_batches_tracked 
 tensor(0)
fc1.weight 
 tensor([[ 0.2246, -0.1272,  0.0163, -0.3089,  0.3511, -0.0189,  0.3025,  0.0770],
        [ 0.2964,  0.2050,  0.2879,  0.0237, -0.3424,  0.0346, -0.0659, -0.0115],
        [ 0.1960, -0.2104, -0.2839,  0.0977, -0.2857, -0.0610, -0.3029,  0.1230],
        [-0.2176,  0.2868, -0.2258,  0.2992, -0.2619,  0.3286,  0.0410,  0.0152]])
fc1.bias 
 tensor([-0.0623,  0.1708, -0.1836, -0.1411])

model.named_parameters()

named_parameters()É para empacotar layer_name e layer_param em uma tupla e depois armazená-la na lista.
Salve apenas os parâmetros que podem ser aprendidos e atualizados. model.named_parameters() Parâmetros de modelo armazenados tensor 的 require_grad 属性都是True. Geralmente é usado para corrigir se os parâmetros de uma determinada camada são treinados , geralmente por meio de model.named_parameters() para obter parâmetros e definir o atributo require_grad .

import torch
import torch.nn as nn
import torch.optim as optim
 
# 定义模型
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.bn = nn.BatchNorm2d(num_features=2)
        self.act = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(8, 4)
        self.softmax = nn.Softmax(dim=-1)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.pool(x)
        x = x.view(-1, 8)
        x = self.fc1(x)
        x = self.softmax(x)
        return x
 
# 初始化模型
model = TheModelClass()
 
# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
for layer_name, layer_param in model.named_parameters():
    print(layer_name, "\n", layer_param)

saída:

conv1.weight 
 Parameter containing:
tensor([[[[ 0.2438, -0.0467,  0.0486],
          [-0.1932, -0.2083,  0.3239],
          [ 0.1712,  0.0379, -0.2381]]],


        [[[ 0.2853,  0.0961,  0.0809],
          [ 0.2526,  0.3138, -0.2243],
          [-0.1627, -0.2958, -0.1995]]]], requires_grad=True)    # require_grad为True
conv1.bias 
 Parameter containing:
tensor([-0.3287, -0.0686], requires_grad=True)
bn.weight 
 Parameter containing:
tensor([1., 1.], requires_grad=True)
bn.bias 
 Parameter containing:
tensor([0., 0.], requires_grad=True)
fc1.weight 
 Parameter containing:
tensor([[ 0.2246, -0.1272,  0.0163, -0.3089,  0.3511, -0.0189,  0.3025,  0.0770],
        [ 0.2964,  0.2050,  0.2879,  0.0237, -0.3424,  0.0346, -0.0659, -0.0115],
        [ 0.1960, -0.2104, -0.2839,  0.0977, -0.2857, -0.0610, -0.3029,  0.1230],
        [-0.2176,  0.2868, -0.2258,  0.2992, -0.2619,  0.3286,  0.0410,  0.0152]],
       requires_grad=True)
fc1.bias 
 Parameter containing:
tensor([-0.0623,  0.1708, -0.1836, -0.1411], requires_grad=True)

modelo.parâmetro()

parameter()Somente os parâmetros são retornados, layer_name não está incluído . 返回结果包含 require_grad,且均为 Ture, principalmente porque os parâmetros padrão precisam ser aprendidos quando a rede é criada, ou seja, require_grad é tudo True.

import torch
import torch.nn as nn
import torch.optim as optim
 
# 定义模型
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.bn = nn.BatchNorm2d(num_features=2)
        self.act = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(8, 4)
        self.softmax = nn.Softmax(dim=-1)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.pool(x)
        x = x.view(-1, 8)
        x = self.fc1(x)
        x = self.softmax(x)
        return x
 
# 初始化模型
model = TheModelClass()
 
# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
for layer_param in model.parameters():
    print(layer_param)

saída:

Parameter containing:
tensor([[[[ 0.2438, -0.0467,  0.0486],
          [-0.1932, -0.2083,  0.3239],
          [ 0.1712,  0.0379, -0.2381]]],


        [[[ 0.2853,  0.0961,  0.0809],
          [ 0.2526,  0.3138, -0.2243],
          [-0.1627, -0.2958, -0.1995]]]], requires_grad=True)
Parameter containing:
tensor([-0.3287, -0.0686], requires_grad=True)
Parameter containing:
tensor([1., 1.], requires_grad=True)
Parameter containing:
tensor([0., 0.], requires_grad=True)
Parameter containing:
tensor([[ 0.2246, -0.1272,  0.0163, -0.3089,  0.3511, -0.0189,  0.3025,  0.0770],
        [ 0.2964,  0.2050,  0.2879,  0.0237, -0.3424,  0.0346, -0.0659, -0.0115],
        [ 0.1960, -0.2104, -0.2839,  0.0977, -0.2857, -0.0610, -0.3029,  0.1230],
        [-0.2176,  0.2868, -0.2258,  0.2992, -0.2619,  0.3286,  0.0410,  0.0152]],
       requires_grad=True)
Parameter containing:
tensor([-0.0623,  0.1708, -0.1836, -0.1411], requires_grad=True)

model.named_modules()

Retorna o nome e a estrutura de cada modelo de camada

import torch
import torch.nn as nn
import torch.optim as optim
 
# 定义模型
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(1, 2, 3)
        self.bn = nn.BatchNorm2d(num_features=2)
        self.act = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(8, 4)
        self.softmax = nn.Softmax(dim=-1)
 
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.pool(x)
        x = x.view(-1, 8)
        x = self.fc1(x)
        x = self.softmax(x)
        return x
 
# 初始化模型
model = TheModelClass()
 
# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
for name, module in model.named_modules():
    print(name,'\n', module)

saída:

 TheModelClass(
  (conv1): Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1))
  (bn): BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (act): ReLU()
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (fc1): Linear(in_features=8, out_features=4, bias=True)
  (softmax): Softmax(dim=-1)
)
conv1 
 Conv2d(1, 2, kernel_size=(3, 3), stride=(1, 1))
bn 
 BatchNorm2d(2, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
act 
 ReLU()
pool 
 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
fc1 
 Linear(in_features=8, out_features=4, bias=True)
softmax 
 Softmax(dim=-1)

Congele certas camadas / deixe apenas algumas camadas aprenderem

No momento, existe esse requisito: treinei uma rede com uma grande quantidade de dados e preciso usar essa rede treinada para testar a precisão de novos assuntos. Quando o modelo de rede permanecer inalterado, quero usar a migração A ideia do aprendizado é pegar os parâmetros treinados e só precisa usar uma quantidade muito pequena de dados dos novos assuntos para treinar a cabeça de classificação do modelo, sendo que as demais camadas do modelo não precisam ser treinadas. Desta forma preciso congelar algumas camadas do modelo, ou seja, apenas treinar algumas camadas do modelo.

Através da análise acima, utilizo para state_dict()ler os parâmetros do modelo e salvá-los:

model_dict = model.state_dict()

Porque state_dict()os parâmetros do modelo obtidos usando a função não têm require_gradatributos, e o artigo também disse que state_dict()os atributos require_grad do tensor dos parâmetros do modelo armazenado são todos False .

Em seguida, excluímos os parâmetros da camada a ser treinada nos parâmetros do modelo salvo, pois somente após a exclusão, após criar um novo objeto modelo e carregar os parâmetros do modelo anterior, os parâmetros da camada a ser treinada não serão excluído. A cobertura de parâmetro do modelo.

model_dict.pop('fc1.weight', None)

saída:

tensor([[ 0.2246, -0.1272,  0.0163, -0.3089,  0.3511, -0.0189,  0.3025,  0.0770],
        [ 0.2964,  0.2050,  0.2879,  0.0237, -0.3424,  0.0346, -0.0659, -0.0115],
        [ 0.1960, -0.2104, -0.2839,  0.0977, -0.2857, -0.0610, -0.3029,  0.1230],
        [-0.2176,  0.2868, -0.2258,  0.2992, -0.2619,  0.3286,  0.0410,  0.0152]])
model_dict.pop('fc1.bias', None)
tensor([-0.0623,  0.1708, -0.1836, -0.1411])

Em seguida, imprimimos os parâmetros do modelo salvo:

for param_tensor in model_dict:
    print(param_tensor, "\n", model_dict[param_tensor])

saída:

conv1.weight 
 tensor([[[[ 0.2438, -0.0467,  0.0486],
          [-0.1932, -0.2083,  0.3239],
          [ 0.1712,  0.0379, -0.2381]]],


        [[[ 0.2853,  0.0961,  0.0809],
          [ 0.2526,  0.3138, -0.2243],
          [-0.1627, -0.2958, -0.1995]]]])
conv1.bias 
 tensor([-0.3287, -0.0686])
bn.weight 
 tensor([1., 1.])
bn.bias 
 tensor([0., 0.])
bn.running_mean 
 tensor([0., 0.])
bn.running_var 
 tensor([1., 1.])
bn.num_batches_tracked 
 tensor(0)

Descobriu-se que os parâmetros da camada que excluímos desapareceram.

Em seguida, criamos um novo objeto de modelo e carregamos os parâmetros do modelo salvo anteriormente no novo objeto de modelo:

model_ = TheModelClass()
model_.load_state_dict(model_dict, strict=False)

saída:

_IncompatibleKeys(missing_keys=['fc1.weight', 'fc1.bias'], unexpected_keys=[])

Então vamos ver como são as propriedades dos parâmetros do novo objeto modelo require_grad:

model_dict_ = model_.named_parameters()
para layer_name, layer_param em model_dict_ :
print(layer_name, “\n”, layer_param)

saída:

conv1.weight 
 Parameter containing:
tensor([[[[ 0.2438, -0.0467,  0.0486],
          [-0.1932, -0.2083,  0.3239],
          [ 0.1712,  0.0379, -0.2381]]],


        [[[ 0.2853,  0.0961,  0.0809],
          [ 0.2526,  0.3138, -0.2243],
          [-0.1627, -0.2958, -0.1995]]]], requires_grad=True)
conv1.bias 
 Parameter containing:
tensor([-0.3287, -0.0686], requires_grad=True)
bn.weight 
 Parameter containing:
tensor([1., 1.], requires_grad=True)
bn.bias 
 Parameter containing:
tensor([0., 0.], requires_grad=True)
fc1.weight 
 Parameter containing:
tensor([[-0.2306, -0.3159, -0.3105, -0.3051,  0.2721, -0.0691,  0.2208, -0.1724],
        [-0.0238, -0.1555,  0.2341, -0.2668,  0.3143,  0.1433,  0.3140, -0.2014],
        [ 0.0696, -0.0250,  0.0316, -0.1065,  0.2260, -0.1009, -0.1990, -0.1758],
        [-0.1782, -0.2045, -0.3030,  0.2643,  0.1951, -0.2213, -0.0040,  0.1542]],
       requires_grad=True)
fc1.bias 
 Parameter containing:
tensor([-0.0472, -0.0569, -0.1912, -0.2139], requires_grad=True)

require_gradPodemos ver que os parâmetros do modelo anterior foram carregados no novo objeto modelo, mas os atributos dos novos parâmetros são todos True , o que não é o que queremos.

A partir da análise acima, podemos ver que state_dict()não podemos obter o efeito que desejamos lendo os parâmetros do modelo, salvando-os e, em seguida, carregando-os em um novo objeto de modelo. Também precisamos de algumas outras operações para completar o objetivo.

Podemos resolver o problema acima de duas maneiras:

require_grad=Falso

Podemos definir as propriedades dos parâmetros das camadas que não precisam ser aprendidas require_gradpara Falso

model_dict_ = model_.named_parameters()
for layer_name, layer_param in model_dict_:
    if 'fc1' in layer_name:
        continue
    else:
        layer_param.requires_grad = False

Então, olhamos para os parâmetros do modelo:

for layer_param in model_.parameters():
    print(layer_param)

saída:

Parameter containing:
tensor([[[[ 0.2438, -0.0467,  0.0486],
          [-0.1932, -0.2083,  0.3239],
          [ 0.1712,  0.0379, -0.2381]]],


        [[[ 0.2853,  0.0961,  0.0809],
          [ 0.2526,  0.3138, -0.2243],
          [-0.1627, -0.2958, -0.1995]]]])
Parameter containing:
tensor([-0.3287, -0.0686])
Parameter containing:
tensor([1., 1.])
Parameter containing:
tensor([0., 0.])
Parameter containing:
tensor([[ 0.0182,  0.1294,  0.0250, -0.1819, -0.2250, -0.2540, -0.2728,  0.2732],
        [ 0.0167, -0.0969,  0.1498, -0.1844,  0.1387,  0.2436,  0.1278, -0.1875],
        [-0.0408,  0.0786,  0.2352,  0.0277,  0.2571,  0.2782,  0.2505, -0.2454],
        [ 0.3369, -0.0804,  0.2677,  0.0927,  0.0433,  0.1716, -0.1870, -0.1738]],
       requires_grad=True)
Parameter containing:
tensor([0.1084, 0.3018, 0.1211, 0.1081], requires_grad=True)

Podemos ver que as propriedades dos parâmetros das camadas que não precisam ser aprendidas require_gradforam todas alteradas para False .

Em seguida, esses parâmetros podem ser enviados para o otimizador:

optimizer = optim.SGD(model_.parameters(), lr=0.001, momentum=0.9)

Definir parâmetros de atualização do otimizador

Se você não deseja atualizar uma determinada camada de rede, a maneira mais simples é não colocar os parâmetros da camada de rede no otimizador:

optimizer = optim.SGD(model_.fc1.parameters(), lr=0.001, momentum=0.9)

Nota: Os parâmetros que estão congelados neste momento ainda estão derivando durante a retropropagação, mas os parâmetros não são atualizados.

Pode-se observar que se este método for adotado, o uso de memória pode ser reduzido, e ao mesmo tempo, se for usado com antecedência, o require_grad=Falsemodelo irá pular parâmetros que não precisam ser calculados e melhorar a velocidade de cálculo, então esses dois métodos podem ser usados ​​juntos.

Referências

Notas de estudo do PyTorch: use state_dict para salvar e carregar modelos

Coleções avançadas de contêiner Python – OrderedDict

Carregamento do modelo de pré-treinamento Pytorch, modificando a estrutura da rede e corrigindo uma certa camada de treinamento de parâmetros, diferentes camadas usam diferentes taxas de aprendizado

Acho que você gosta

Origin blog.csdn.net/qq_41990294/article/details/128942601
Recomendado
Clasificación