06_Entrenamiento del modelo PyTorch [tasa de aprendizaje y clase base del optimizador]

Cuando se determinan los datos, el modelo y la función de pérdida, se ha determinado el modelo matemático de la tarea y luego se debe seleccionar uno adecuado.
El Optimizer optimiza el modelo.
Todos los optimizadores en PyTorch (como: optim.Adadelta, optim.SGD, optim.RMSprop, etc.) son
Una subclase de Optimizer, algunos métodos de uso común se definen en Optimizer, incluidos zero_grad(), step(closure), state_dict(), load_state_dict(state_dict) y add_param_group(param_group)
La gestión de parámetros del optimizador se basa en el concepto de grupos , y se pueden configurar lr, momento, peso_decaimiento , etc. específicos para cada grupo de parámetros.
El grupo de parámetros se representa como una lista (self.param_groups) en el optimizador, donde cada elemento es un dict, que representa un parámetro y su configuración correspondiente, incluidos 'params', 'weight_decay', 'lr', 'momentum', etc. en el campo de dictado.

1. Código concepto básico

import torch
import torch.optim as optim


w1 = torch.randn(2, 2)
w1.requires_grad = True

w2 = torch.randn(2, 2)
w2.requires_grad = True

w3 = torch.randn(2, 2)
w3.requires_grad = True

print("w1",w1)
print("w2",w2)
print("w3",w3)

# 一个参数组
optimizer_1 = optim.SGD([w1, w3], lr=0.1)
print('len(optimizer.param_groups): ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')

# 两个参数组
optimizer_2 = optim.SGD([{'params': w1, 'lr': 0.1},
                         {'params': w2, 'lr': 0.001}])
print('len(optimizer.param_groups): ', len(optimizer_2.param_groups))
print(optimizer_2.param_groups)

2.  grado_cero()

Función: borrar el gradiente a cero. Esto se hace antes de cada actualización, ya que PyTorch no pone a cero automáticamente los gradientes.

Código y salida:

import torch
import torch.optim as optim

# ----------------------------------- zero_grad

w1 = torch.randn(2, 2)
w1.requires_grad = True

w2 = torch.randn(2, 2)
w2.requires_grad = True

optimizer = optim.SGD([w1, w2], lr=0.001, momentum=0.9)

print(optimizer.param_groups)
print("=======================")
print(optimizer.param_groups[0])
print("=======================")
print(optimizer.param_groups[0]['params'])
print("=======================")
print(optimizer.param_groups[0]['params'][0])  #参数w1

optimizer.param_groups[0]['params'][0].grad = torch.randn(2, 2)  

print('参数w1的梯度:')
print(optimizer.param_groups[0]['params'][0].grad, '\n')  # 参数组,第一个参数(w1)的梯度

optimizer.zero_grad()
print('执行zero_grad()之后,参数w1的梯度:')
print(optimizer.param_groups[0]['params'][0].grad)  # 参数组,第一个参数(w1)的梯度

 

3. estado_dict()

Función: Obtiene los parámetros actuales del modelo y los devuelve en forma de diccionario ordenado.

En este diccionario ordenado, la clave es el nombre del parámetro de cada capa y el valor es el parámetro.
Código y salida:
import torch.nn as nn
import torch.nn.functional as F


# ----------------------------------- state_dict
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 1, 3)  #输出一个特征图,需要3个 3*3 的矩阵
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(1 * 3 * 3, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 1 * 3 * 3)
        x = F.relu(self.fc1(x))
        return x


net = Net()

# 获取网络当前参数
net_state_dict = net.state_dict()

print('net_state_dict类型:', type(net_state_dict))
print('net_state_dict管理的参数: ', net_state_dict.keys())
for key, value in net_state_dict.items():
    print('参数名: ', key, '\t大小: ',  value.shape)

4. agregar_param_grupo() 

efecto:

Agregue un conjunto de parámetros al grupo de parámetros administrado por el optimizador, y puede personalizar lr, impulso, peso_decaimiento , etc. para este conjunto de parámetros , que se usan comúnmente con precisión.

Código y salida:

# coding: utf-8

import torch
import torch.optim as optim

# ----------------------------------- add_param_group

w1 = torch.randn(2, 2)
w1.requires_grad = True

w2 = torch.randn(2, 2)
w2.requires_grad = True

w3 = torch.randn(2, 2)
w3.requires_grad = True

# 一个参数组
optimizer_1 = optim.SGD([w1, w2], lr=0.1)
print('当前参数组个数: ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')

# 增加一个参数组
print('增加一组参数 w3\n')
optimizer_1.add_param_group({'params': w3, 'lr': 0.001, 'momentum': 0.8})

print('当前参数组个数: ', len(optimizer_1.param_groups))
print(optimizer_1.param_groups, '\n')

print('可以看到,参数组是一个list,一个元素是一个dict,每个dict中都有lr, momentum等参数,这些都是可单独管理,单独设定,十分灵活!')

 5. cargar_estado_dict(estado_dict)

efecto:

Cargue los parámetros en state_dict a la red actual, a menudo se usa en ajuste fino.

Código y salida:

import torch
import torch.nn as nn
import torch.nn.functional as F


# ----------------------------------- load_state_dict

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 1, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(1 * 3 * 3, 2)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 1 * 3 * 3)
        x = F.relu(self.fc1(x))
        return x

    def zero_param(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.constant_(m.weight.data, 0)
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                torch.nn.init.constant_(m.weight.data, 0)
                m.bias.data.zero_()
net = Net()

# 保存,并加载模型参数(仅保存模型参数)
torch.save(net.state_dict(), 'net_params.pkl')   # 假设训练好了一个模型net
pretrained_dict = torch.load('net_params.pkl')

# 将net的参数全部置0,方便对比
net.zero_param()
net_state_dict = net.state_dict()
print('conv1层的权值为:\n', net_state_dict['conv1.weight'], '\n')

# 通过load_state_dict 加载参数
net.load_state_dict(pretrained_dict)
print('加载之后,conv1层的权值变为:\n', net_state_dict['conv1.weight'])

 6. paso (cierre)

Función: Realice una actualización de peso de un paso.

Supongo que te gusta

Origin blog.csdn.net/zhang2362167998/article/details/128853413
Recomendado
Clasificación