import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
import torch.utils.data as Data
import cv2 as cv
EPOCH = 1  # number of passes over the training set
BATCH_SIZE = 50
LR = 0.001  # Adam learning rate
DOWNLOAD_MNIST = False  # set True on the first run to fetch MNIST into ./mnist/

# Training set: ToTensor() turns each 28x28 uint8 image into a float tensor
# of shape (1, 28, 28) with values scaled to [0, 1].
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
)
train_loader = Data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

# Test set: no transform is given, so the raw uint8 tensors are used directly.
test_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=False,
    download=DOWNLOAD_MNIST,
)
# `.data` / `.targets` replace the deprecated `.test_data` / `.test_labels`.
# Slice first, then convert: (2000, 28, 28) uint8 -> (2000, 1, 28, 28) float32
# in [0, 1], matching the scaling ToTensor() applies to the training images.
test_x = torch.unsqueeze(test_data.data[:2000], dim=1).float() / 255.
test_y = test_data.targets[:2000]
class CNN(nn.Module):
    """Two-stage convolutional classifier for 28x28 single-channel images.

    Input is a float tensor of shape (batch, 1, 28, 28). A batch axis is
    required, so a lone image must first be unsqueezed to (1, 1, 28, 28).
    Output is raw logits of shape (batch, 10); no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: (B, 1, 28, 28) -> 5x5 conv, pad 2 -> (B, 16, 28, 28)
        #          -> 2x2 max-pool -> (B, 16, 14, 14)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Stage 2: (B, 16, 14, 14) -> 5x5 conv, pad 2 -> (B, 32, 14, 14)
        #          -> 2x2 max-pool -> (B, 32, 7, 7)
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Linear head over the flattened 32*7*7 feature map -> 10 class scores.
        self.out = nn.Linear(in_features=32 * 7 * 7, out_features=10)

    def forward(self, x):
        features = self.conv2(self.conv1(x))
        # Flatten (B, 32, 7, 7) -> (B, 1568) while keeping the batch axis.
        flat = features.view(features.size(0), -1)
        return self.out(flat)
def PicGet(picpath):
    """Load an image file as a one-image batch for the MNIST network.

    The image is read with OpenCV, converted to grayscale, resized to 28x28
    with area interpolation, and scaled to [0, 1]. The scaling matches what the
    network sees during training (ToTensor()) and testing (``/255.``).

    Args:
        picpath: path to an image file readable by ``cv.imread``.

    Returns:
        float32 tensor of shape (1, 1, 28, 28).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``picpath``.
    """
    img = cv.imread(picpath)
    # cv.imread reports failure by returning None instead of raising.
    if img is None:
        raise FileNotFoundError(f"cannot read image: {picpath}")
    grayImg = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
    micro = cv.resize(grayImg, (28, 28), interpolation=cv.INTER_AREA)
    # NOTE(review): MNIST digits are light-on-dark; a dark-on-light photo may
    # need inverting (255 - micro) before classification — confirm per input.
    # uint8 (28, 28) -> float32 in [0, 1], the same range as the training data.
    ImgIn = torch.from_numpy(micro).float() / 255.
    # Add channel and batch axes: (28, 28) -> (1, 1, 28, 28).
    return ImgIn.unsqueeze(0).unsqueeze(0)
def getNum(ins):
    """Return the predicted class index for each row of a score tensor.

    Args:
        ins: tensor whose dim 1 holds per-class scores, e.g. (batch, 10) logits.

    Returns:
        numpy array of the argmax indices along dim 1, with size-1 dims
        squeezed away (a single-row input gives a 0-d array).
    """
    _, best = torch.max(ins, dim=1)
    best = best.squeeze()
    return best.numpy()
cnn = CNN()
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()  # takes raw logits and integer class labels
accHis = []  # test accuracy, sampled every 50 training steps
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):
        output = cnn(b_x)
        loss = loss_func(output, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if step % 50 == 0:
            # Score the held-out batch without building an autograd graph over
            # all 2000 test images. The eval()/train() toggle keeps this
            # correct if dropout or batch-norm layers are ever added.
            cnn.eval()
            with torch.no_grad():
                test_output = cnn(test_x)
            cnn.train()
            pred_y = getNum(test_output)
            accuracy = float((pred_y == test_y.numpy()).sum()) / test_y.size(0)
            accHis.append(accuracy)
print('ok')
# Uncomment to persist the trained weights.
# torch.save(cnn.state_dict(), 'net_params.pkl')

# Classify a user-supplied image with the trained network (inference only,
# so no autograd graph is needed).
ImgNum = PicGet('test.jpg')
cnn.eval()
with torch.no_grad():
    outs = cnn(ImgNum)
anws = getNum(outs)
print(anws)

# Test-set accuracy history: one point per 50 training steps.
plt.plot(accHis)
plt.xlabel('evaluation (every 50 training steps)')
plt.ylabel('test accuracy')
plt.show()
# pytorch cnn
# Reposted from blog.csdn.net/futangxiang4793/article/details/82731508