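# 版权声明:转载请注明出处及原文地址。 https://blog.csdn.net/zl1085372438/article/details/86635964
# (Copyright notice: when reposting, credit the source and link to the original article above.)
# 一个epoch下来,Test ACC: 0.9892
# (Reported result: after one epoch of training, test accuracy was 0.9892.)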
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets,transforms
from torch.autograd import Variable
from matplotlib import pyplot as plt
# Prefer the GPU when one is available, but fall back to the CPU so the script
# still runs on machines without CUDA (with a hard-coded 'cuda' device every
# later .to(device) call would fail there).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class CNN(nn.Module):
    """Convolutional classifier producing 10 raw logits per single-channel image.

    Five stages of 3x3 Conv2d -> BatchNorm2d -> ReLU (the last three also
    followed by 2x2 max pooling) reduce a 1x28x28 input to 768x3x3 features.
    These are flattened and passed through a fully connected head
    (6912 -> 4096 -> 1024 -> 10) with BatchNorm1d and ReLU between the
    linear layers. No softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Trailing comments give (channels, height, width) for a 1x28x28 input.
        self.layer1 = self._stage(1, 48, padding=0)                # 48,26,26
        self.layer2 = self._stage(48, 96, padding=1)               # 96,26,26
        self.layer3 = self._stage(96, 192, padding=0, pool=True)   # 192,12,12
        self.layer4 = self._stage(192, 384, padding=1, pool=True)  # 384,6,6
        self.layer5 = self._stage(384, 768, padding=1, pool=True)  # 768,3,3
        head = [
            nn.Linear(768 * 3 * 3, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(),
            nn.Linear(4096, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 10),
        ]
        self.fc = nn.Sequential(*head)

    @staticmethod
    def _stage(in_ch, out_ch, padding, pool=False):
        """Build one Conv2d(3x3) -> BatchNorm2d -> ReLU [-> MaxPool2d(2, 2)] stage."""
        parts = [
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=padding),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        ]
        if pool:
            parts.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*parts)

    def forward(self, x):
        """Map a (N, 1, H, W) batch to (N, 10) logits."""
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            x = stage(x)
        flat = x.reshape(x.shape[0], -1)
        return self.fc(flat)
# Instantiate the network on the selected device, print its layer summary,
# and put it in training mode for the loop further down.
# (Module.to and Module.train both return the module itself.)
model = CNN().to(device)
print(model)
model.train()
# Preprocessing: ToTensor, then Normalize computing (x - 0.5) / 0.5 per channel.
# MNIST images are single-channel (the CNN's first Conv2d takes 1 input
# channel), so mean/std carry one value each. Three-element lists make recent
# torchvision versions raise a broadcast-shape RuntimeError on 1x28x28 tensors.
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])
dataset_train = datasets.MNIST(root='./data', transform=img_transform, train=True, download=True)
dataset_test = datasets.MNIST(root='./data', transform=img_transform, train=False, download=True)
# Only the training data is shuffled; evaluation visits the test set in order.
train_loader = torch.utils.data.DataLoader(dataset=dataset_train, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=dataset_test, batch_size=64, shuffle=False)
# images,label = next(iter(train_loader))
# print(images.shape)
# print(label.shape)
# images_example = torchvision.utils.make_grid(images)
# images_example = images_example.numpy().transpose(1,2,0)
# mean = [0.5,0.5,0.5]
# std = [0.5,0.5,0.5]
# images_example = images_example*std + mean
# plt.imshow(images_example)
# plt.show()
def Get_ACC(net=None, loader=None):
    """Compute, print and return classification accuracy.

    Args:
        net: module to evaluate; defaults to the global ``model``.
        loader: iterable of (images, labels) batches exposing ``.dataset``;
            defaults to the global ``test_loader``.

    Returns:
        Fraction of samples whose arg-max prediction equals the label
        (0.0 when the loader's dataset is empty).

    The network is evaluated in eval mode under ``torch.no_grad()``. In eval
    mode BatchNorm uses its running statistics instead of updating them from
    evaluation data. The network's previous train/eval mode is restored
    afterwards, even if evaluation raises.
    """
    net = model if net is None else net
    loader = test_loader if loader is None else loader
    # Inputs go to whichever device holds the network's parameters.
    dev = next(net.parameters()).device
    total_num = len(loader.dataset)
    correct = 0
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for batch_imgs, batch_labels in loader:
                batch_imgs = batch_imgs.to(dev)
                batch_labels = batch_labels.to(dev)
                pred = net(batch_imgs).argmax(dim=1)
                correct += (pred == batch_labels).sum().item()
    finally:
        net.train(was_training)
    acc = correct / total_num if total_num else 0.0
    print('correct={},Test ACC:{:.5}'.format(correct, acc))
    return acc
# Adam with its default hyper-parameters. Cross-entropy is applied to the
# raw logits returned by CNN.forward, which ends in a Linear layer with no softmax.
optimizer = torch.optim.Adam(model.parameters())
loss_f = nn.CrossEntropyLoss()
Get_ACC()  # accuracy of the untrained network, as a baseline
for epoch in range(5):
    print('epoch:{}'.format(epoch))
    # Re-enter training mode at the start of every epoch. BatchNorm then uses
    # batch statistics and updates its running estimates, whatever mode the
    # evaluation call left the model in.
    model.train()
    for cnt, (batch_imgs, batch_labels) in enumerate(train_loader):
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)
        out = model(batch_imgs)
        loss = loss_f(out, batch_labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if cnt % 100 == 0:
            print('epoch:{},cnt:{},loss:{}'.format(epoch, cnt, loss.item()))
    Get_ACC()
# NOTE(review): this pickles the whole module object, so loading the file later
# requires the CNN class to be importable under the same name. Saving
# model.state_dict() would be more portable, but changes the file's format.
torch.save(model, 'model')