3. Entrenar el modelo
1. Obtenga un lote de datos del conjunto de entrenamiento.
2. Pase este lote de datos a la red.
3. Calcule la pérdida.
4. Calcule el gradiente de la función de pérdida con respecto a los pesos de la red neuronal.
5. Actualice los pesos usando el gradiente para reducir la pérdida.
6. Repita los pasos 1 a 5 hasta completar una época (un recorrido completo por el conjunto de entrenamiento).
7. Repita los pasos 1 a 6 durante tantas épocas como sean necesarias para alcanzar la precisión deseada.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# Widen the line width used when printing tensors so rows wrap less often.
torch.set_printoptions(linewidth=120)
# Report the installed library versions (one per line) to aid reproducibility.
print(torch.__version__, torchvision.__version__, sep="\n")
def get_num_correct(preds, labels):
    """Return how many samples in a batch were classified correctly.

    Args:
        preds: Score tensor, presumably (batch, num_classes); the predicted
            class of each row is the column holding that row's largest value.
        labels: Tensor of target class indices, one per row of ``preds``.

    Returns:
        The count of rows whose predicted class equals its label, as a Python int.
    """
    predicted_classes = torch.argmax(preds, dim=1)
    hits = predicted_classes == labels
    return hits.sum().item()
class Network(nn.Module):
    """Small CNN: two conv/ReLU/max-pool stages followed by three linear layers.

    The layer sizes assume single-channel 28x28 inputs: each 5x5 convolution
    (no padding) trims 4 pixels and each 2x2 max-pool halves the spatial size,
    so 28 -> 24 -> 12 -> 8 -> 4, leaving a 12 x 4 x 4 feature map per image
    for ``fc1``. The final layer produces 10 class scores.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)
        # 12 channels * 4 * 4 spatial positions after the second pooling stage.
        self.fc1 = nn.Linear(in_features=12 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)

    def forward(self, t):
        """Map a batch of images to raw class scores.

        Args:
            t: Input tensor of shape (batch, 1, 28, 28).

        Returns:
            Tensor of shape (batch, 10) of unnormalized logits; no softmax is
            applied here.
        """
        t = F.relu(self.conv1(t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)
        t = F.relu(self.conv2(t))
        t = F.max_pool2d(t, kernel_size=2, stride=2)
        # Flatten everything except the batch dimension. Unlike
        # reshape(-1, 12 * 4 * 4), this never re-splits rows across samples:
        # an input of the wrong spatial size now fails loudly in fc1 instead
        # of silently producing a different number of output rows.
        t = t.flatten(start_dim=1)
        t = F.relu(self.fc1(t))
        t = F.relu(self.fc2(t))
        t = self.out(t)
        return t
# Train Network on the FashionMNIST training split and print per-epoch stats.
train_set = torchvision.datasets.FashionMNIST(
    root="./data/FashionMNIST",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
    ]),
)

network = Network()
batch_size = 100
num_epochs = 5
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
optimizer = optim.Adam(network.parameters(), lr=0.01)  # optimizer

for epoch in range(num_epochs):
    total_loss = 0
    total_correct = 0

    for batch in train_loader:
        images, labels = batch

        preds = network(images)  # forward pass
        # F.cross_entropy averages over the batch by default (reduction='mean').
        loss = F.cross_entropy(preds, labels)

        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()  # compute gradients
        optimizer.step()  # update weights

        total_loss += loss.item()
        total_correct += get_num_correct(preds, labels)

    # total_loss is a sum of per-batch *mean* losses, so the epoch's mean loss
    # is total_loss divided by the number of batches, not by batch_size.
    print("epoch:", epoch, "total_correct:", total_correct, "total_loss:", total_loss, "mean_loss",
          total_loss / len(train_loader))
D:\pytorch\pytorchbasis\venv\Scripts\python.exe D:\pytorch\pytorchbasis\train.py
1.13.1+cu116
0.14.1+cu116
epoch: 0 total_correct: 46634 total_loss: 350.1262176781893 mean_loss 3.5012621767818928
epoch: 1 total_correct: 51356 total_loss: 232.00757797062397 mean_loss 2.3200757797062397
epoch: 2 total_correct: 51899 total_loss: 216.55600735545158 mean_loss 2.165560073554516
epoch: 3 total_correct: 52378 total_loss: 204.50732965767384 mean_loss 2.045073296576738
epoch: 4 total_correct: 52925 total_loss: 192.59698635339737 mean_loss 1.9259698635339737
进程已结束,退出代码0