[Práctica de PyTorch] Regresión lineal de PyTorch de entrada de base cero [explicación del código línea por línea]

prefacio

Esta sección contiene un pequeño ejemplo de cómo comenzar con la antorcha

La explicación en video del código línea por línea está en: https://www.bilibili.com/video/BV1nS4y1u76S?spm_id_from=333.999.0.0

Principalmente sobre ejemplos de regresión lineal.

  • Crea tus propios datos

  • Construir un modelo de regresión lineal

  • completar el proceso de formación

  • Pantalla de dibujo

Modelo lineal

94a77f1d8fdaa1c1053e4965aa4c275a.pngdonde k es el peso y b es el término de sesgo.

En general, el modelo lineal es para ajustar el k y b de los cuales son en realidad w y sesgo

como aquí

33c9ce863f6896e9431a164a90b70152.png

código como un todo

Datos de simulación

Agregue ruido blanco gaussiano (un grupo de números aleatorios que se ajustan a una distribución normal con una media de 0 y una varianza de 1) y establezca x en 512 puntos, es decir, el número de muestras es 512

edca9992d8ea27bd93beb8714b158d8e.png

Modelo lineal

Debido a que cada valor de la entrada y cada valor de la salida es en realidad una dimensión de 1, feature_num=1, el modelo lineal es

class LinearModel(nn.Module):
    def __init__(self, in_fea, out_fea):
        super(LinearModel, self).__init__()
        self.out = nn.Linear(in_fea, out_fea)
    def forward(self, x):
        x = self.out(x)
        return x
7b87ac96e286b0cb41ffb446066491a9.png

Definir la función de pérdida y el optimizador

optimizer = torch.optim.SGD(model.parameters(), lr=0.02)

loss_func = nn.MSELoss()

Cambiar la dimensión de los datos a la dimensión de entrada del modelo

62db9f00bd1215d9e37b1dc2a7728a86.pngAgregar una dimensión a la dimensión de la característica

entrenamiento y visualización

la rutina es

  1. razonamiento directo

  2. pérdida

  3. gradiente claro

  4. retropropagación

  5. actualizar pesos

plt.ion()
for step in range(200):
    prediction = model(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step%10 == 0:
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
        plt.xlim(0,1.1)
        plt.ylim(0, 20)
        [w, b] = model.parameters()
        plt.text(0, 0.5, 'loss=%.4f, k=%.2f, b=%2f'%(loss.item(), w.item(), b.item() ),fontdict={'size': 20, 'color':  'red'})
        plt.pause(0.5)
plt.ioff()
plt.show()

Entre ellos, ion es para abrir el modo interactivo e ioff es para cerrar el modo interactivo, y puede dibujar dinámicamente la imagen para ver los cambios.

Observe dinámicamente la transformación de la línea ajustada

1d0c9107f38b4beb288412bf807596cc.png 0c0a9120f2fd09e7fa31dd8c9c4878f0.png 89ab799498523a521e86d7662f67cae9.png

Lectura recomendada:

Mi intercambio de reclutamiento escolar por Internet de 2022

Mi Resumen 2021

Hablando de la diferencia entre la publicación de algoritmos y la publicación de desarrollo

Resumen de salarios de investigación y desarrollo de reclutamiento de escuelas de Internet

Para series de tiempo, todo lo que puedes hacer.

¿Qué es el problema de la secuencia espacio-temporal? ¿Qué modelos se utilizan principalmente para tales problemas? ¿Cuáles son las principales aplicaciones?

Número público: coche caracol AI

Mantente humilde, mantente disciplinado, mantente progresista

91180d846d68ee78da035e9b1ad376da.png

Envíe [Snail] para obtener una copia del "Proyecto práctico de IA" (AI Snail Car)

Envíe [1222] para obtener una buena nota de cepillado de leetcode

Envíe [AI Four Classics] para obtener cuatro libros electrónicos clásicos de AI

Supongo que te gusta

Origin blog.csdn.net/qq_33431368/article/details/123516036
Recomendado
Clasificación