# PyTorch: MNIST handwritten-digit recognition with a CNN.
#
# Adapted from: https://blog.csdn.net/zl1085372438/article/details/86635964
#
# After a single epoch, test accuracy reaches about 0.9892.

import torch
import torch.nn as nn
import torchvision
from torchvision import datasets,transforms
from torch.autograd import Variable
from matplotlib import pyplot as plt

# Fall back to CPU when CUDA is unavailable so the script still runs
# (a hard-coded 'cuda' device raises at the first .to(device) call otherwise).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class CNN(nn.Module):
    """Five 3x3 conv blocks followed by a three-layer classifier.

    Designed for 1x28x28 MNIST inputs; produces 10 class logits.
    Spatial sizes per stage: 26 -> 26 -> 12 -> 6 -> 3.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = self._conv_block(1, 48, padding=0)               # 48 x 26 x 26
        self.layer2 = self._conv_block(48, 96, padding=1)              # 96 x 26 x 26
        self.layer3 = self._conv_block(96, 192, padding=0, pool=True)  # 192 x 12 x 12
        self.layer4 = self._conv_block(192, 384, padding=1, pool=True) # 384 x 6 x 6
        self.layer5 = self._conv_block(384, 768, padding=1, pool=True) # 768 x 3 x 3
        self.fc = nn.Sequential(
            nn.Linear(768 * 3 * 3, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Linear(4096, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 10),
        )

    @staticmethod
    def _conv_block(in_channels, out_channels, padding, pool=False):
        """3x3 conv + batch norm + ReLU, optionally followed by 2x2 max-pool."""
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ]
        if pool:
            stages.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*stages)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) logits."""
        for block in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            x = block(x)
        x = x.reshape(x.shape[0], -1)
        return self.fc(x)





# Build the network, move its parameters to the chosen device, and show
# the architecture summary.
model = CNN().to(device)
print(model)

# Training mode: BatchNorm layers use per-batch statistics.
model.train()

# MNIST tensors are single-channel, so Normalize takes ONE mean/std value;
# the original 3-value lists ([0.5,0.5,0.5]) raise a runtime error when
# torchvision applies them to a 1-channel tensor.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
dataset_train = datasets.MNIST(root='./data', transform=img_transform, train=True, download=True)
dataset_test = datasets.MNIST(root='./data', transform=img_transform, train=False, download=True)

train_loader = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=dataset_test, batch_size=64, shuffle=False)

# To visually inspect a batch: grid = torchvision.utils.make_grid(images),
# un-normalize with grid * 0.5 + 0.5, then plt.imshow on the HWC transpose.

def Get_ACC():
    """Evaluate the model on the MNIST test set and print accuracy.

    Uses the module-level ``model``, ``test_loader``, ``dataset_test`` and
    ``device``. Switches the model to eval mode for the measurement so that
    BatchNorm uses its running statistics (evaluating in train mode skews
    results and mutates the running stats), then restores the previous mode.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total_num = len(dataset_test)
    # no_grad: skip building autograd graphs — evaluation needs no gradients.
    with torch.no_grad():
        for batch_imgs, batch_labels in test_loader:
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)
            out = model(batch_imgs)
            _, pred = torch.max(out, 1)
            correct += (pred == batch_labels).sum().item()
    if was_training:
        model.train()
    acc = correct / total_num
    print('correct={},Test ACC:{:.5}'.format(correct, acc))



optimizer = torch.optim.Adam(model.parameters())
loss_f = nn.CrossEntropyLoss()

# Baseline accuracy before any training, then 5 epochs of SGD-style updates.
Get_ACC()
for epoch in range(5):
    print('epoch:{}'.format(epoch))
    # Variable is deprecated since PyTorch 0.4 — plain tensors carry autograd.
    for cnt, (batch_imgs, batch_labels) in enumerate(train_loader):
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)
        out = model(batch_imgs)
        loss = loss_f(out, batch_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if cnt % 100 == 0:
            # .item() is the supported way to read a scalar loss (not .data).
            print('epoch:{},cnt:{},loss:{}'.format(epoch, cnt, loss.item()))
    Get_ACC()


# NOTE(review): this pickles the whole module object; saving
# model.state_dict() is the recommended, more portable format.
torch.save(model, 'model')

# Original article: https://blog.csdn.net/zl1085372438/article/details/86635964