Pytorch actualiza algunos parámetros (parámetros de congelación) consideraciones

Pytorch experimental versión 1.2.0

Durante el proceso de entrenamiento, puede ser necesario corregir una parte de los parámetros del modelo y solo actualizar otra parte de los parámetros. Hay dos formas de lograr este objetivo: una es configurar la capa de red para que no actualice los parámetros a falso y la otra es pasar solo los parámetros que se actualizarán al definir el optimizador. Por supuesto, la mejor práctica es pasar solo los parámetros de requirements_grad = True en el optimizador, de modo que la memoria ocupada será menor y la eficiencia será mayor.

Uno, establezca el parámetro en falso

import torch
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的网络
class net(nn.Module):
    def __init__(self, num_class=10):
        super(net, self).__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc2 = nn.Linear(4, num_class)
    def forward(self, x):
        return self.fc2(self.fc1(x))


model = net()

# 冻结fc1层的参数
for name, param in model.named_parameters():
    if "fc1" in name:
        param.requires_grad = False


loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2)  # 传入的是所有的参数
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

for epoch in range(10):
    x = torch.randn((3, 8))
    label = torch.randint(0,10,[3]).long()
    output = model(x)

    loss = loss_fn(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

Se puede ver en los resultados del experimento que, siempre que se establezcan los requisitos_grad = Falso, aunque se pasen todos los parámetros del modelo, solo se actualizarán los requisitos_grad = Verdadero.

Dos, pasar directamente los parámetros a actualizar

# 定义一个简单的网络
class net(nn.Module):
    def __init__(self, num_class=3):
        super(net, self).__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc2 = nn.Linear(4, num_class)
    def forward(self, x):
        return self.fc2(self.fc1(x))


model = net()

# 冻结fc1层的参数
# for name, param in model.named_parameters():
#     if "fc1" in name:
#         param.requires_grad = False


loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc2.parameters(), lr=1e-2)  # 只传入fc2的参数
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

for epoch in range(10):
    x = torch.randn((3, 8))
    label = torch.randint(0,3,[3]).long()
    output = model(x)

    loss = loss_fn(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
print()

Se puede ver que sólo se actualizan los parámetros pasados ​​por el optimizador, aunque los parámetros que no se pasaron se pueden derivar, los parámetros todavía no se actualizan.

Tres, la mejor forma de escribir:

Es combinar los dos anteriores, establecer el parámetro no actualizado en False y no pasar.

# 定义一个简单的网络
class net(nn.Module):
    def __init__(self, num_class=3):
        super(net, self).__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc2 = nn.Linear(4, num_class)
    def forward(self, x):
        return self.fc2(self.fc1(x))


model = net()

# 冻结fc1层的参数
for name, param in model.named_parameters():
    if "fc1" in name:
        param.requires_grad = False


loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc2.parameters(), lr=1e-2)
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

for epoch in range(10):
    x = torch.randn((3, 8))
    label = torch.randint(0,3,[3]).long()
    output = model(x)

    loss = loss_fn(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
print()

 

Supongo que te gusta

Origin blog.csdn.net/Answer3664/article/details/108493753
Recomendado
Clasificación