# Handwritten digit recognition (MNIST) implemented with PyTorch.
import os
import torch
import torchvision
from torchvision.transforms import Compose,ToTensor,Normalize
from torch.utils.data import DataLoader
from torch import nn, optim
import torch.nn.functional as F
import numpy as np
# Mini-batch sizes: the first is the default for get_dataloader (training),
# the second is used by test() for evaluation.
TRAIN_BATCH_SIZE, TEST_BATCH_SIZE = 128, 1000
def get_dataloader(train=True, batch_size=TRAIN_BATCH_SIZE, root="./data", shuffle=None):
    """Build a DataLoader over the MNIST train or test split.

    Images are converted to tensors and normalized with a fixed mean/std.

    Args:
        train: True for the training split, False for the test split.
        batch_size: number of samples per batch.
        root: directory where torchvision stores (and, with download=True,
            fetches) the MNIST files.
        shuffle: whether to reshuffle every epoch. Defaults to ``train``, so
            the training split is shuffled and the evaluation split is read
            in a fixed, reproducible order.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.
    """
    if shuffle is None:
        # Shuffling only matters for SGD; evaluation metrics are order-independent,
        # so keep the test split deterministic.
        shuffle = train
    transform_fn = Compose([
        ToTensor(),
        # Single-channel mean/std (the values commonly used for MNIST).
        Normalize((0.1307,), (0.3081,)),
    ])
    dataset = torchvision.datasets.MNIST(root=root, train=train, download=True, transform=transform_fn)
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
class ImageNet(nn.Module):
    """Two-layer fully connected classifier for 28x28 single-channel images.

    Each sample is flattened to 784 features, passed through a 28-unit hidden
    layer with ReLU, then projected to 10 class scores. ``forward`` returns
    log-probabilities (log-softmax over the last dimension), which pairs with
    ``F.nll_loss``.
    """

    def __init__(self):
        super().__init__()
        # The names fc1/fc2 become the state_dict keys written to and read
        # from the checkpoint files, so they must stay stable.
        self.fc1 = nn.Linear(784, 28)
        self.fc2 = nn.Linear(28, 10)

    def forward(self, data):
        # view() needs exactly 784 elements per entry along dim 0,
        # e.g. (N, 1, 28, 28) or (N, 784).
        flat = data.view(data.size(0), 784)
        hidden = F.relu(self.fc1(flat))
        return F.log_softmax(self.fc2(hidden), dim=-1)
# Module-level network and Adam optimizer shared by train() and test().
model = ImageNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Resume from a previous run's checkpoint when present. The optimizer state is
# loaded only if its own file exists: previously a lone model.pkl crashed here.
# map_location='cpu' lets checkpoints written on a GPU machine load on CPU, and
# the model in this script lives on the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
if os.path.exists('./models/model.pkl'):
    model.load_state_dict(torch.load('./models/model.pkl', map_location='cpu'))
    if os.path.exists('./models/optimizer.pkl'):
        optimizer.load_state_dict(torch.load('./models/optimizer.pkl', map_location='cpu'))
def _save_checkpoint():
    """Write the model and optimizer state dicts to ./models, creating the directory if needed."""
    # torch.save does not create missing directories; without this the first
    # save raises FileNotFoundError on a fresh checkout.
    os.makedirs('./models', exist_ok=True)
    torch.save(model.state_dict(), './models/model.pkl')
    torch.save(optimizer.state_dict(), './models/optimizer.pkl')


def train(epoch):
    """Run one training epoch over the MNIST training split.

    Prints the current batch loss every 10 batches. Checkpoints model and
    optimizer state on every 100th batch (starting with the first) and once
    more after the final batch, so no progress from the epoch is lost.

    Args:
        epoch: epoch index; used only in the progress message.
    """
    model.train(mode=True)
    train_dataloader = get_dataloader(train=True)
    for idx, (data, target) in enumerate(train_dataloader):
        optimizer.zero_grad()
        output = model(data)
        # The model's forward ends in log_softmax, so NLL loss here is cross-entropy.
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if idx % 10 == 0:
            print('第%d轮次,损失值为%f' % (epoch, loss.item()))
        if idx % 100 == 0:
            _save_checkpoint()
    # The periodic save misses batches after the last multiple of 100;
    # persist the end-of-epoch state as well.
    _save_checkpoint()
def test():
    """Evaluate the model on the MNIST test split and print loss and accuracy.

    Loss and correct-prediction counts are accumulated as per-sample sums and
    divided by the total number of samples. The printed values are therefore
    true per-sample averages even when the last batch is smaller than
    TEST_BATCH_SIZE; the old mean-of-batch-means over-weighted that batch.
    Prints NaN for both values if the loader yields no samples.
    """
    model.eval()
    test_dataloader = get_dataloader(train=False, batch_size=TEST_BATCH_SIZE)
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for data, target in test_dataloader:
            output = model(data)
            # reduction='sum' so batches contribute in proportion to their size.
            total_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            total_correct += pred.eq(target).sum().item()
            total_samples += target.size(0)
    if total_samples:
        mean_loss = total_loss / total_samples
        accuracy = total_correct / total_samples
    else:
        mean_loss = accuracy = float('nan')
    print('模型损失%f,平均准确率%f' % (mean_loss, accuracy))
if __name__ == '__main__':
    # Train for five epochs, evaluating on the test split after each one.
    for epoch in range(5):
        train(epoch)
        test()