# PyTorch CNN

import torch
import torch.nn as nn
import torchvision
import matplotlib.pyplot as plt
import torch.utils.data as Data
import cv2 as cv

# Training hyperparameters.
EPOCH = 1          # number of passes over the training set
BATCH_SIZE = 50    # samples per mini-batch drawn from the training loader
LR = 0.001         # learning rate
# Forwarded as `download=` for the training split. It stays False, so nothing
# is fetched; switch it to True for a first run when ./mnist/ is still empty.
DOWNLOAD_MNIST = False


# Training split. ToTensor() turns every image into a float tensor scaled to
# the [0, 1] range.
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
)

train_loader = Data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

# Held-out split. No transform is attached, so the raw image tensor is scaled
# by hand below.
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# `.data` / `.targets` are the current attribute names; the older
# `test_data` / `test_labels` aliases only survive as deprecated properties
# that emit a warning. Slicing happens before the float conversion so only the
# 2000 images actually used are converted.
# Shape goes from (2000, 28, 28) to (2000, 1, 28, 28), values in [0, 1].
test_x = torch.unsqueeze(test_data.data[:2000], dim=1).float() / 255.
test_y = test_data.targets[:2000]



class CNN(nn.Module):
    """Small convolutional classifier that emits 10 scores per input image.

    Two stages of (5x5 convolution, stride 1, padding 2) -> ReLU -> 2x2
    max-pooling are followed by a single linear layer. The linear layer takes
    32 * 7 * 7 features, which is what the two stages produce for
    single-channel 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 input channel -> 16 feature maps, then halve height/width.
        self.conv1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Stage 2: 16 -> 32 feature maps, then halve height/width again.
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                in_channels=16,
                out_channels=32,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        # Classification head over the flattened feature maps.
        self.out = nn.Linear(in_features=32 * 7 * 7, out_features=10)

    def forward(self, x):
        """Run a batch through both stages and the linear head.

        The input is a batched tensor laid out as (batch, channels, height,
        width), so a single image needs a leading batch dimension added
        before it is passed in.
        """
        features = self.conv2(self.conv1(x))
        # The convolution stages return a multi-dimensional tensor per sample;
        # collapse everything after the batch axis into one feature vector.
        batch = features.size(0)
        flat = features.view(batch, -1)
        return self.out(flat)



def PicGet(picpath):
    """Load an image file and turn it into a single-image input batch.

    The image is read with OpenCV, converted to grayscale, resized to 28x28
    and scaled from 0..255 down to [0, 1], the same range as the MNIST
    tensors the network is trained and evaluated on.

    Args:
        picpath: Path of the image file to read.

    Returns:
        A float tensor of shape (1, 1, 28, 28) with values in [0, 1].

    Raises:
        OSError: If OpenCV cannot read or decode ``picpath``.
    """
    img = cv.imread(picpath)
    if img is None:
        # cv.imread reports failure by returning None instead of raising;
        # fail here with a clear message rather than inside cvtColor.
        raise OSError('cannot read image: {!r}'.format(picpath))
    gray = cv.cvtColor(img, cv.COLOR_BGR2GRAY)
    micro = cv.resize(gray, (28, 28), interpolation=cv.INTER_AREA)
    # Pixels arrive as 0..255; divide by 255 so the value range agrees with
    # the [0, 1] inputs used for training.
    scaled = torch.from_numpy(micro).float() / 255.
    # Add the channel axis, then the batch axis: (28, 28) -> (1, 1, 28, 28).
    return scaled.unsqueeze(0).unsqueeze(0)

def getNum(ins):
    """Return, as a NumPy value, the index of the largest entry in each row.

    ``ins`` is reduced along dimension 1; size-1 dimensions are squeezed out
    of the index tensor before it is converted, so a single-row input yields
    a 0-d array.
    """
    _, top_index = torch.max(ins, dim=1)
    squeezed = top_index.squeeze()
    return squeezed.numpy()



# Model, optimizer and loss used for training.
cnn = CNN()
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)
loss_func = nn.CrossEntropyLoss()
accHis = []  # test-set accuracy, sampled once every 50 training steps
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):
        output = cnn(b_x)
        loss = loss_func(output, b_y)
        # Clear gradients left over from the previous step before
        # backpropagating the current loss.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 50 == 0:
            # Evaluation needs no gradients; no_grad() avoids building an
            # autograd graph for the whole test batch on every sample point.
            with torch.no_grad():
                test_output = cnn(test_x)
            pred_y = torch.max(test_output, 1)[1]
            # Fraction of test samples whose predicted class equals the label.
            accuracy = (pred_y == test_y).sum().item() / test_y.size(0)
            accHis.append(accuracy)


print('ok')
# To persist the trained weights:
#  torch.save(cnn.state_dict(), 'net_params.pkl')

# Classify one image from disk and print the predicted class index.
ImgNum = PicGet('test.jpg')
with torch.no_grad():
    outs = cnn(ImgNum)
anws = getNum(outs)
print(anws)

# Plot how the sampled test accuracy evolved during training.
plt.plot(accHis)
plt.show()






# You may also like
#
# Reposted from blog.csdn.net/futangxiang4793/article/details/82731508