[PyTorch] Proyecto de predicción de imágenes con una CNN: entrenar el modelo


3. Entrenar el modelo

1. Obtenga un lote de datos del conjunto de entrenamiento.
2. Pase este lote de datos a la red.
3. Calcule la pérdida.
4. Calcule el gradiente de la función de pérdida con respecto a los pesos de la red neuronal.
5. Actualice los pesos usando el gradiente para reducir la pérdida.
6. Repita los pasos 1 a 5 hasta completar un ciclo (una época: una pasada completa por el conjunto de entrenamiento).
7. Repita los pasos 1 a 6 durante tantas épocas como sea necesario para alcanzar el nivel de precisión deseado.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

# Widen the print line so tensors shown in the console wrap less often.
torch.set_printoptions(linewidth=120)

# Report the installed PyTorch and torchvision versions, one per line.
print(torch.__version__, torchvision.__version__, sep="\n")


def get_num_correct(preds, labels):
    """Count the predictions that match their labels.

    Args:
        preds: Score tensor; the index of the maximum along dim 1 is taken
            as the predicted class for each row.
        labels: Tensor of target class indices, compared element-wise with
            the predicted classes.

    Returns:
        The number of matches, as a Python number (via ``.item()``).
    """
    predicted = preds.argmax(dim=1)
    hits = predicted == labels
    return hits.sum().item()


class Network(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    Takes single-channel images (``conv1`` has ``in_channels=1``). ``fc1`` is
    sized for 28x28 inputs, which leave 12 feature maps of 4x4 after the conv
    stack. ``forward`` returns the raw output of the last linear layer
    (10 values per sample, no softmax applied).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

        # 28x28 -> conv(5) 24x24 -> pool 12x12 -> conv(5) 8x8 -> pool 4x4,
        # with 12 channels: 12 * 4 * 4 features per sample.
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Map a batch of images to 10 output scores per image.

        Args:
            t: Image tensor, batched ``(N, 1, H, W)`` or unbatched
                ``(1, H, W)``; H and W must reduce to 4x4 after the conv
                stack (e.g. 28x28).

        Returns:
            Tensor of shape ``(N, 10)``. An unbatched input yields ``(1, 10)``.

        Raises:
            RuntimeError: if the flattened feature size does not match
                ``fc1`` (i.e. the input has the wrong spatial size).
        """
        t = F.relu(self.conv1(t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        t = F.relu(self.conv2(t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)

        # Flatten per sample while keeping the batch dimension. The previous
        # reshape(-1, 12 * 4 * 4) could silently regroup features across
        # samples for wrongly sized inputs (e.g. one 44x44 image became four
        # rows); now a size mismatch fails loudly in fc1 instead. An unbatched
        # (C, H, W) input is treated as a batch of one, as before.
        n_samples = t.shape[0] if t.dim() == 4 else 1
        t = t.reshape(n_samples, -1)

        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        t = self.out(t)

        return t


# Training script: load the FashionMNIST training split (converted to tensors
# by ToTensor), then train `Network` with Adam for `num_epochs` epochs,
# printing per-epoch statistics.
train_set = torchvision.datasets.FashionMNIST(
    root="./data/FashionMNIST",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
    ]),
)

network = Network()

batch_size = 100
num_epochs = 5
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
optimizer = optim.Adam(network.parameters(), lr=0.01)  # optimizer

for epoch in range(num_epochs):
    total_loss = 0  # sum of the per-batch loss values in this epoch
    total_correct = 0  # number of correctly classified samples in this epoch

    for images, labels in train_loader:
        preds = network(images)  # forward pass
        # F.cross_entropy uses its default reduction ("mean"), so `loss`
        # is the average loss over this batch.
        loss = F.cross_entropy(preds, labels)

        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()  # compute gradients
        optimizer.step()  # update weights

        total_loss += loss.item()
        total_correct += get_num_correct(preds, labels)

    # total_loss adds one batch-mean loss per batch, so the mean loss per
    # batch divides by the number of batches, not by batch_size.
    print("epoch:", epoch, "total_correct:", total_correct, "total_loss:", total_loss, "mean_loss",
          total_loss / len(train_loader))
D:\pytorch\pytorchbasis\venv\Scripts\python.exe D:\pytorch\pytorchbasis\train.py 
1.13.1+cu116
0.14.1+cu116
epoch: 0 total_correct: 46634 total_loss: 350.1262176781893 mean_loss 3.5012621767818928
epoch: 1 total_correct: 51356 total_loss: 232.00757797062397 mean_loss 2.3200757797062397
epoch: 2 total_correct: 51899 total_loss: 216.55600735545158 mean_loss 2.165560073554516
epoch: 3 total_correct: 52378 total_loss: 204.50732965767384 mean_loss 2.045073296576738
epoch: 4 total_correct: 52925 total_loss: 192.59698635339737 mean_loss 1.9259698635339737

进程已结束,退出代码0

Supongo que te gusta

Origin blog.csdn.net/weixin_66896881/article/details/128663153
Recomendado
Clasificación