Notes on updating only some parameters (freezing parameters) in PyTorch

PyTorch version used for the experiments: 1.2.0

During training it is sometimes necessary to keep part of a model's parameters fixed and update only the rest. There are two ways to achieve this: one is to set requires_grad=False on the layers that should not be updated; the other is to pass only the parameters to be updated when constructing the optimizer. The best practice is to pass only the parameters with requires_grad=True to the optimizer, which uses less memory and is more efficient.
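As a quick sketch of that best practice (my own addition, assuming model is the network defined in the examples below), a filter over model.parameters() keeps the frozen parameters out of the optimizer:

# Sketch: hand the optimizer only the parameters that still require gradients
# (assumes requires_grad has already been set to False on the frozen layers)
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-2)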

One, set requires_grad to False

import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple network
class net(nn.Module):
    def __init__(self, num_class=10):
        super(net, self).__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc2 = nn.Linear(4, num_class)
    def forward(self, x):
        return self.fc2(self.fc1(x))


model = net()

# Freeze the parameters of the fc1 layer
for name, param in model.named_parameters():
    if "fc1" in name:
        param.requires_grad = False


loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-2)  # all of the model's parameters are passed in
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

for epoch in range(10):
    x = torch.randn((3, 8))
    label = torch.randint(0,10,[3]).long()
    output = model(x)

    loss = loss_fn(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

It can be seen from the results of the experiment that as long as requires_grad=False is set, only the parameters with requires_grad=True are updated, even though all of the model's parameters were passed to the optimizer.
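A small check (my own addition, not from the original experiment) makes the freezing visible: parameters with requires_grad=False never receive a gradient, so their .grad stays None after backward().

# Sketch: inspect each parameter after loss.backward()
for name, param in model.named_parameters():
    print(name, param.requires_grad, param.grad is None)
# fc1.* -> requires_grad=False, grad stays None
# fc2.* -> requires_grad=True,  grad is populated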

Two, pass only the parameters to be updated to the optimizer

# Define a simple network
class net(nn.Module):
    def __init__(self, num_class=3):
        super(net, self).__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc2 = nn.Linear(4, num_class)
    def forward(self, x):
        return self.fc2(self.fc1(x))


model = net()

# Freeze the parameters of the fc1 layer (deliberately left commented out in this example)
# for name, param in model.named_parameters():
#     if "fc1" in name:
#         param.requires_grad = False


loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc2.parameters(), lr=1e-2)  # pass in only fc2's parameters
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

for epoch in range(10):
    x = torch.randn((3, 8))
    label = torch.randint(0,3,[3]).long()
    output = model(x)

    loss = loss_fn(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
print()

It can be seen that only the parameters passed to the optimizer are updated. Gradients are still computed for the parameters that were not passed in (their requires_grad is still True), but those parameters are never updated.
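A short sketch (my own addition) makes the difference from approach one clear: here fc1 still computes and stores gradients on every backward pass, which is wasted work, even though the optimizer never applies them.

# Sketch: fc1's gradient is computed because requires_grad is still True,
# but fc1 is not in the optimizer's parameter list, so it is never updated
print(model.fc1.weight.grad is not None)  # True: gradient computed but unused
print(model.fc2.weight.grad is not None)  # True: gradient computed and applied by step()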

Three, the recommended way

Combine the two approaches above: set requires_grad=False on the parameters that should not be updated, and at the same time do not pass them to the optimizer.

# Define a simple network
class net(nn.Module):
    def __init__(self, num_class=3):
        super(net, self).__init__()
        self.fc1 = nn.Linear(8, 4)
        self.fc2 = nn.Linear(4, num_class)
    def forward(self, x):
        return self.fc2(self.fc1(x))


model = net()

# Freeze the parameters of the fc1 layer
for name, param in model.named_parameters():
    if "fc1" in name:
        param.requires_grad = False


loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.fc2.parameters(), lr=1e-2)  # pass in only fc2's parameters
print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)

for epoch in range(10):
    x = torch.randn((3, 8))
    label = torch.randint(0,3,[3]).long()
    output = model(x)

    loss = loss_fn(output, label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print("model.fc1.weight", model.fc1.weight)
print("model.fc2.weight", model.fc2.weight)
print()
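As a final sanity check (my own addition, not part of the original post), the optimizer can be inspected to confirm it only holds the trainable parameters:

# Sketch: compare the parameters held by the optimizer with the model's trainable parameters
num_in_optimizer = sum(p.numel() for group in optimizer.param_groups for p in group["params"])
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(num_in_optimizer, num_trainable)  # both count only fc2's weight and bias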

 


Origin blog.csdn.net/Answer3664/article/details/108493753