Pytorch experimental versión 1.2.0
Durante el proceso de entrenamiento, puede ser necesario corregir una parte de los parámetros del modelo y solo actualizar otra parte de los parámetros. Hay dos formas de lograr este objetivo: una es configurar la capa de red para que no actualice los parámetros a falso y la otra es pasar solo los parámetros que se actualizarán al definir el optimizador. Por supuesto, la mejor práctica es pasar solo los parámetros de requirements_grad = True en el optimizador, de modo que la memoria ocupada será menor y la eficiencia será mayor.
Uno, establezca el parámetro en falso
import torch
import torch.nn as nn
import torch.optim as optim
# 定义一个简单的网络
class net(nn.Module):
def __init__(self, num_class=10):
super(net, self).__init__()
self.fc1 = nn.Linear(8, 4)
self.fc2 = nn.Linear(4, num_class)
def forward(self, x):
return self.fc2(self.fc1(x))
model = net()
# 冻结fc1层的参数
for name, param in model.named_parameters():
if "fc1" in name:
param.requires_grad = False
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2) # 传入的是所有的参数
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
for epoch in range(10):
x = torch.randn((3, 8))
label = torch.randint(0,10,[3]).long()
output = model(x)
loss = loss_fn(output, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
Se puede ver en los resultados del experimento que, siempre que se establezcan los requisitos_grad = Falso, aunque se pasen todos los parámetros del modelo, solo se actualizarán los requisitos_grad = Verdadero.
Dos, pasar directamente los parámetros a actualizar
# 定义一个简单的网络
class net(nn.Module):
def __init__(self, num_class=3):
super(net, self).__init__()
self.fc1 = nn.Linear(8, 4)
self.fc2 = nn.Linear(4, num_class)
def forward(self, x):
return self.fc2(self.fc1(x))
model = net()
# 冻结fc1层的参数
# for name, param in model.named_parameters():
# if "fc1" in name:
# param.requires_grad = False
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc2.parameters(), lr=1e-2) # 只传入fc2的参数
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
for epoch in range(10):
x = torch.randn((3, 8))
label = torch.randint(0,3,[3]).long()
output = model(x)
loss = loss_fn(output, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
print()
Se puede ver que sólo se actualizan los parámetros pasados por el optimizador, aunque los parámetros que no se pasaron se pueden derivar, los parámetros todavía no se actualizan.
Tres, la mejor forma de escribir:
Es combinar los dos anteriores, establecer el parámetro no actualizado en False y no pasar.
# 定义一个简单的网络
class net(nn.Module):
def __init__(self, num_class=3):
super(net, self).__init__()
self.fc1 = nn.Linear(8, 4)
self.fc2 = nn.Linear(4, num_class)
def forward(self, x):
return self.fc2(self.fc1(x))
model = net()
# 冻结fc1层的参数
for name, param in model.named_parameters():
if "fc1" in name:
param.requires_grad = False
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc2.parameters(), lr=1e-2)
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
for epoch in range(10):
x = torch.randn((3, 8))
label = torch.randint(0,3,[3]).long()
output = model(x)
loss = loss_fn(output, label)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
print()