The Most Basic Neural Network Starter Example: MNIST Handwritten Digit Recognition

I. Fully Connected Network Implementation

1. Implementation Notes

(1) Input and Output

Input: a 28*28 handwritten digit image
Output: the digit that the image represents
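The numpy sketch below illustrates this input/output contract. It uses a blank placeholder image and made-up output scores, both hypothetical.

- The 28*28 image is flattened into 784 values, which is the form the fully connected network consumes.
- The predicted digit is the index of the largest of the 10 output scores.

```python
import numpy as np

# Hypothetical stand-in for one MNIST image: a 28x28 array of grayscale values in [0, 1].
image = np.zeros((28, 28), dtype=np.float32)

# The fully connected network takes each image as a flat 784-element vector,
# which is also how the mnist.pkl.gz arrays used below store each sample.
x = image.reshape(-1)
print(x.shape)  # (784,)

# The network outputs 10 scores (logits), one per digit 0-9; these values are made up.
logits = np.array([0.1, -1.2, 0.3, 2.5, 0.0, -0.4, 0.7, 1.1, -2.0, 0.2])

# The predicted digit is the index of the largest score.
predicted_digit = int(np.argmax(logits))
print(predicted_digit)  # 3
```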

(2) Network Structure

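Each image is flattened into a 784-dimensional vector and passed through two hidden layers: 784 → 128 (ReLU, Dropout 0.5) → 512 (ReLU, Dropout 0.5) → 10 output scores, one per digit.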
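The parameter count of this 784-128-512-10 network can be worked out by hand, which gives a feel for the model's size. Each nn.Linear(n_in, n_out) layer stores an n_out x n_in weight matrix plus n_out biases. Dropout has no parameters. A small pure-Python sketch of that arithmetic:

```python
# Layer widths of MNIST_NN: 784 inputs -> 128 -> 512 -> 10 outputs.
layer_sizes = [784, 128, 512, 10]


def linear_params(n_in, n_out):
    # nn.Linear(n_in, n_out): an n_out x n_in weight matrix plus n_out biases.
    return n_in * n_out + n_out


per_layer = [linear_params(a, b) for a, b in zip(layer_sizes, layer_sizes[1:])]
total = sum(per_layer)
print(per_layer)  # [100480, 66048, 5130]
print(total)      # 171658
```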

2. Code Implementation

import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from pathlib import Path
import requests
import pickle
import gzip
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch import optim
from time import time


class MNIST_NN(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(784, 128)
        self.hidden2 = nn.Linear(128, 512)
        # self.hidden3 = nn.Linear(256, 512)
        self.out = nn.Linear(512, 10)
        self.dropout = nn.Dropout(0.5)
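        # Validation accuracy observed with different architectures:
        # 97.64%  3 hidden layers
        # 97.79%  2 hidden layers, 256 neurons in the 2nd hidden layer
        # 97.82%  2 hidden layers, 512 neurons in the 2nd hidden layer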

    def forward(self, x):
        x = F.relu(self.hidden1(x))
        x = self.dropout(x)
        x = F.relu(self.hidden2(x))
        # x = self.dropout(x)
        # x = F.relu(self.hidden3(x))
        x = self.dropout(x)
        return self.out(x)


#  ############################### Inspect the network's parameters
# model = MNIST_NN()
# print(model)
# # Print the weight parameters
# for name, parameter in model.named_parameters():
#     print(name, parameter)

def get_model():
    model = MNIST_NN()
    return model, optim.Adam(model.parameters(), lr=0.001)  # Adam


def get_data(bs):
    DATA_PATH = Path("data")
    PATH = DATA_PATH / "mnist"

    PATH.mkdir(parents=True, exist_ok=True)
    URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"
    FILENAME = "mnist.pkl.gz"

    if not (PATH / FILENAME).exists():
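        response = requests.get(URL + FILENAME)
        response.raise_for_status()  # fail loudly instead of saving an error page as the .gz file
        (PATH / FILENAME).write_bytes(response.content)  # write_bytes also closes the file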
    with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
    x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid))
    train_ds = TensorDataset(x_train, y_train)
    valid_ds = TensorDataset(x_valid, y_valid)
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2)
    )


def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)
    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()

    return loss.item(), len(xb)


# Training function: trains for `steps` epochs and reports the validation loss after each one
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        start_time = time()
        model.train()  # Training mode: Dropout is active (and BatchNorm, if present, uses batch statistics)
        # Update the weights and biases
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)

        model.eval()  # Evaluation mode: Dropout is disabled (and BatchNorm, if present, uses its running statistics)
        with torch.no_grad():
            losses, nums = zip(*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl])
            val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
            end_time = time()
            print("Current epoch: " + str(step) + ", average validation loss: " + str(val_loss), ", time elapsed:", end_time - start_time)

    # Final accuracy on the validation set (the model is still in eval mode)
    correct = 0
    total = 0
    with torch.no_grad():
        for xb, yb in valid_dl:
            outputs = model(xb)
            _, predicted = torch.max(outputs, 1)
            total += yb.size(0)
            correct += (yb == predicted).sum().item()
    print("Acc: ", 100 * correct / total, "%")


if __name__ == '__main__':
    bs = 64  # batch_size = 64
    epoch = 30
    train_dl, valid_dl = get_data(bs)
    loss_func = F.cross_entropy
    model, opt = get_model()
    fit(epoch, model, loss_func, opt, train_dl, valid_dl)
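The model.train() and model.eval() calls in fit() matter because of the two Dropout(0.5) layers. PyTorch's nn.Dropout uses "inverted" dropout, which behaves differently in the two modes:

- In training mode, each activation is zeroed with probability p, and the survivors are scaled by 1/(1-p).
- In eval mode, it is the identity.

The numpy function below is a hypothetical re-implementation for illustration only. The name `dropout` and its arguments are assumptions, not PyTorch's code.

```python
import numpy as np


def dropout(x, p, training, rng):
    """Hypothetical inverted dropout, mirroring what nn.Dropout(p) is documented to do."""
    if not training or p == 0.0:
        # Eval mode: the input passes through unchanged.
        return x
    keep = rng.random(x.shape) >= p   # each element survives with probability 1 - p
    return x * keep / (1.0 - p)       # rescale so the expected value stays the same


rng = np.random.default_rng(0)
x = np.ones(10_000)
train_out = dropout(x, 0.5, training=True, rng=rng)   # entries are 0.0 or 2.0
eval_out = dropout(x, 0.5, training=False, rng=rng)   # identical to x
print(train_out.mean())  # close to 1.0, matching the mean of x
```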

3. Results

C:\Users\Administrator\.conda\envs\torzml\python.exe D:/Project/PythonProject/LSTM_text/others/MNISTReco.py
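Current epoch: 0, average validation loss: 0.1683511969923973 , time elapsed: 3.590179204940796
Current epoch: 1, average validation loss: 0.1312192772962153 , time elapsed: 3.058769941329956
Current epoch: 2, average validation loss: 0.11431313703283667 , time elapsed: 2.9534740447998047
Current epoch: 3, average validation loss: 0.10498983032917604 , time elapsed: 3.0124051570892334
Current epoch: 4, average validation loss: 0.09412320185815916 , time elapsed: 3.009488344192505
Current epoch: 5, average validation loss: 0.093143627073057 , time elapsed: 2.890739917755127
Current epoch: 6, average validation loss: 0.09063829299034551 , time elapsed: 3.1035313606262207
Current epoch: 7, average validation loss: 0.08887825601268559 , time elapsed: 3.6453356742858887
Current epoch: 8, average validation loss: 0.08634029937214219 , time elapsed: 3.641497850418091
Current epoch: 9, average validation loss: 0.08860128401129042 , time elapsed: 4.121753215789795
Current epoch: 10, average validation loss: 0.08582983395410701 , time elapsed: 3.940110445022583
Current epoch: 11, average validation loss: 0.08012033958600369 , time elapsed: 3.7954063415527344
Current epoch: 12, average validation loss: 0.08083907972197048 , time elapsed: 3.8298981189727783
Current epoch: 13, average validation loss: 0.07870770492306911 , time elapsed: 3.784552812576294
Current epoch: 14, average validation loss: 0.0787701717220014 , time elapsed: 3.679483652114868
Current epoch: 15, average validation loss: 0.07930033117127605 , time elapsed: 3.6713767051696777
Current epoch: 16, average validation loss: 0.07984851049119607 , time elapsed: 3.6588127613067627
Current epoch: 17, average validation loss: 0.07337796090731863 , time elapsed: 3.672180414199829
Current epoch: 18, average validation loss: 0.07625965428813361 , time elapsed: 3.6567420959472656
Current epoch: 19, average validation loss: 0.07491345010597725 , time elapsed: 3.6410393714904785
Current epoch: 20, average validation loss: 0.07532293595063966 , time elapsed: 3.6485989093780518
Current epoch: 21, average validation loss: 0.07646674913240131 , time elapsed: 3.6318893432617188
Current epoch: 22, average validation loss: 0.07439087963479106 , time elapsed: 3.6764001846313477
Current epoch: 23, average validation loss: 0.07468655214742757 , time elapsed: 3.6644670963287354
Current epoch: 24, average validation loss: 0.07627506786901504 , time elapsed: 3.6473772525787354
Current epoch: 25, average validation loss: 0.07737481498552952 , time elapsed: 3.6624276638031006
Current epoch: 26, average validation loss: 0.07628958516074927 , time elapsed: 3.6760411262512207
Current epoch: 27, average validation loss: 0.07489939673700718 , time elapsed: 3.670935869216919
Current epoch: 28, average validation loss: 0.07496044216285809 , time elapsed: 3.631747007369995
Current epoch: 29, average validation loss: 0.07229171380188491 , time elapsed: 3.705427646636963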
Acc:  97.94 %

Process finished with exit code 0
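The "average validation loss" printed each epoch is a sample-weighted mean. loss_batch returns each batch's mean loss together with the batch size, and fit() computes sum(loss × n) / sum(n).

The weighting matters because the 10,000 validation images split into 78 batches of 128 plus a final batch of 16. A plain mean of the batch losses would over-weight that last small batch. A pure-Python sketch with made-up loss values:

```python
# Per-batch (mean loss, batch size) pairs; the loss values are hypothetical.
# 10000 validation images with batch size 128 -> 78 full batches plus one batch of 16.
losses = [0.2] * 78 + [0.9]
nums = [128] * 78 + [16]

# Sample-weighted mean, as in fit(): np.sum(np.multiply(losses, nums)) / np.sum(nums)
weighted = sum(loss * n for loss, n in zip(losses, nums)) / sum(nums)

# Unweighted mean of batch losses, for comparison: the small last batch counts as much as a full one
naive = sum(losses) / len(losses)

print(weighted, naive)
```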

II. Convolutional Neural Network Implementation

1. Implementation Notes

(1) Input and Output

Input: a 28*28 handwritten digit image
Output: the digit that the image represents
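Unlike the fully connected network, the CNN needs each image back in 2-D form. The DataLoader still yields flattened (batch, 784) tensors, so MnistNet.forward first calls x.view(-1, 1, 28, 28).

The numpy sketch below shows the same reshape on a random placeholder batch (a hypothetical stand-in for real data). It also checks that row r of an image corresponds to elements 28r to 28r+27 of the flat vector.

```python
import numpy as np

# Hypothetical batch of 64 flattened images, shaped like a DataLoader batch: (64, 784).
batch = np.random.default_rng(0).random((64, 784), dtype=np.float32)

# Equivalent of x.view(-1, 1, 28, 28): (batch, channels=1, height=28, width=28).
images = batch.reshape(-1, 1, 28, 28)
print(images.shape)  # (64, 1, 28, 28)

# Row 3 of image 5 is the slice batch[5, 84:112] of the flat data.
assert np.array_equal(images[5, 0, 3], batch[5, 84:112])
```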

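(2) Network Structure

Three convolutional blocks are followed by one linear layer:

- conv1: 1 → 16 channels, one 5×5 convolution, ReLU, 2×2 max pooling.
- conv2: 16 → 32 → 32 channels, two 5×5 convolutions with ReLU, 2×2 max pooling.
- conv3: 32 → 64 channels, one 5×5 convolution, ReLU.
- The 64×7×7 feature map is flattened and fed to a linear layer with 10 outputs.

Every convolution uses stride 1 and padding 2, so only the pooling layers shrink the 28×28 input, first to 14×14 and then to 7×7.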
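The shape comments in the code follow from the output-size formula for nn.Conv2d and nn.MaxPool2d, with dilation 1: out = floor((in + 2*padding - kernel) / stride) + 1. MaxPool2d(2) defaults to stride 2 and padding 0.

The pure-Python sketch below traces the spatial size through MnistNet. It recovers the 64*7*7 input size of the final linear layer.

```python
def conv2d_out(size, kernel, stride=1, padding=0):
    # Output side length of nn.Conv2d with dilation 1.
    return (size + 2 * padding - kernel) // stride + 1


def maxpool_out(size, kernel):
    # nn.MaxPool2d(kernel) defaults to stride = kernel and padding = 0.
    return (size - kernel) // kernel + 1


s = 28
s = conv2d_out(s, 5, 1, 2)   # conv1 convolution: 28
s = maxpool_out(s, 2)        # conv1 pooling: 14
s = conv2d_out(s, 5, 1, 2)   # conv2, first convolution: 14
s = conv2d_out(s, 5, 1, 2)   # conv2, second convolution: 14
s = maxpool_out(s, 2)        # conv2 pooling: 7
s = conv2d_out(s, 5, 1, 2)   # conv3 convolution: 7
flat = 64 * s * s
print(s, flat)  # 7 3136
```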

2. Code Implementation

import gzip
import pickle
from pathlib import Path

import requests
import torch
from torch import nn
from torch import optim
from torch.utils.data import TensorDataset, DataLoader


# Network structure: three convolutional blocks followed by a linear classifier
class MnistNet(nn.Module):
    def __init__(self):
        super(MnistNet, self).__init__()
        self.conv1 = nn.Sequential(  # (b, 1, 28, 28)
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),  # (b, 16, 28, 28)
            nn.ReLU(),
            nn.MaxPool2d(2)  # (b, 16, 14, 14)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2),  # (b, 32, 14, 14)
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2),  # (b, 32, 14, 14)
            nn.ReLU(),
            nn.MaxPool2d(2)  # (b, 32, 7, 7)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, stride=1, padding=2),  # (b, 64, 7, 7)
            nn.ReLU()
        )
        self.out = nn.Linear(64*7*7, 10)

    def forward(self, x):
        x = x.view(-1, 1, 28, 28)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0), -1)
        output = self.out(x)
        return output


def get_data(bs):
    DATA_PATH = Path("data")
    PATH = DATA_PATH / "mnist"

    PATH.mkdir(parents=True, exist_ok=True)
    URL = "https://storage.googleapis.com/cvdf-datasets/mnist/"
    FILENAME = "mnist.pkl.gz"

    if not (PATH / FILENAME).exists():
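        response = requests.get(URL + FILENAME)
        response.raise_for_status()  # fail loudly instead of saving an error page as the .gz file
        (PATH / FILENAME).write_bytes(response.content)  # write_bytes also closes the file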
    with gzip.open((PATH / FILENAME).as_posix(), "rb") as f:
        ((x_train, y_train), (x_valid, y_valid), _) = pickle.load(f, encoding="latin-1")
    x_train, y_train, x_valid, y_valid = map(torch.tensor, (x_train, y_train, x_valid, y_valid))
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x_train, y_train, x_valid, y_valid = x_train.to(device), y_train.to(device), x_valid.to(device), y_valid.to(device)
    train_ds = TensorDataset(x_train, y_train)
    valid_ds = TensorDataset(x_valid, y_valid)
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2)
    )


def get_model():
    model = MnistNet()
    return model, optim.Adam(model.parameters(), lr=0.001)


def accuracy(predictions, labels):
    pred = torch.max(predictions, 1)[1]
    rights = pred.eq(labels.data.view_as(pred)).sum()
    return rights, len(labels)


def fit(model, optimizer, bs, epochs, loss_func):
    train_dl, valid_dl = get_data(bs)
    for epoch in range(epochs):
        train_rights = []
        for batch_index, (xb, yb) in enumerate(train_dl):
            model.train()
            output = model(xb)
            loss = loss_func(output, yb)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            rights = accuracy(output, yb)
            train_rights.append(rights)
            # print(batch_index)

            if batch_index % 100 == 0:
                model.eval()
                valid_rights = []
                with torch.no_grad():  # no gradients are needed for evaluation
                    for xb_val, yb_val in valid_dl:
                        rights = accuracy(model(xb_val), yb_val)
                        valid_rights.append(rights)

                # Accuracy: (number correct, number seen), summed over the collected batches
                train_rate = (sum([tup[0] for tup in train_rights]), sum([tup[1] for tup in train_rights]))
                valid_rate = (sum([tup[0] for tup in valid_rights]), sum([tup[1] for tup in valid_rights]))

                print('Current epoch: {} [{} / {} ({:.2f}%)]\tLoss: {:.6f} \t Train accuracy: {:.2f}%\t Validation accuracy: {:.2f}%'.format(
                    epoch,
                    batch_index * bs,
                    len(train_dl.dataset),
                    100. * batch_index / len(train_dl),
                    loss.item(),
                    100. * train_rate[0].cpu().numpy() / train_rate[1],
                    100. * valid_rate[0].cpu().numpy() / valid_rate[1]
                ))
                train_rights = []


if __name__ == '__main__':
    bs = 64
    epochs = 10
    loss_func = nn.CrossEntropyLoss()
    model, optimizer = get_model()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    fit(model, optimizer, bs, epochs, loss_func)
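The same weights-plus-biases counting rule used for nn.Linear applies to nn.Conv2d, whose weight has shape out_channels x in_channels x k x k. The pure-Python sketch below tallies MnistNet's parameters and compares them with the 784-128-512-10 fully connected network from Part I.

```python
def conv_params(c_in, c_out, k):
    # nn.Conv2d weight: c_out x c_in x k x k, plus c_out biases.
    return c_out * c_in * k * k + c_out


def linear_params(n_in, n_out):
    # nn.Linear weight: n_out x n_in, plus n_out biases.
    return n_in * n_out + n_out


cnn = [
    conv_params(1, 16, 5),          # conv1
    conv_params(16, 32, 5),         # conv2, first convolution
    conv_params(32, 32, 5),         # conv2, second convolution
    conv_params(32, 64, 5),         # conv3
    linear_params(64 * 7 * 7, 10),  # out
]
fc_total = linear_params(784, 128) + linear_params(128, 512) + linear_params(512, 10)

print(cnn)       # [416, 12832, 25632, 51264, 31370]
print(sum(cnn))  # 121514
print(fc_total)  # 171658
```

The CNN is deeper but has fewer parameters. This is because each 5×5 kernel is shared across every position of the image instead of having a separate weight per pixel.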

3. Results

C:\Users\Administrator\.conda\envs\torzml\python.exe D:/Project/PythonProject/LSTM_text/others/Mnist_Conv.py
Current epoch: 0 [0 / 50000 (0.00%)]	Loss: 2.301290 	 Train accuracy: 7.81%	 Validation accuracy: 11.80%
Current epoch: 0 [6400 / 50000 (12.79%)]	Loss: 0.341838 	 Train accuracy: 76.25%	 Validation accuracy: 94.57%
Current epoch: 0 [12800 / 50000 (25.58%)]	Loss: 0.088409 	 Train accuracy: 94.03%	 Validation accuracy: 96.25%
Current epoch: 0 [19200 / 50000 (38.36%)]	Loss: 0.208718 	 Train accuracy: 96.33%	 Validation accuracy: 97.55%
Current epoch: 0 [25600 / 50000 (51.15%)]	Loss: 0.123537 	 Train accuracy: 97.09%	 Validation accuracy: 97.80%
Current epoch: 0 [32000 / 50000 (63.94%)]	Loss: 0.018433 	 Train accuracy: 97.47%	 Validation accuracy: 97.76%
Current epoch: 0 [38400 / 50000 (76.73%)]	Loss: 0.097949 	 Train accuracy: 97.95%	 Validation accuracy: 98.00%
Current epoch: 0 [44800 / 50000 (89.51%)]	Loss: 0.070966 	 Train accuracy: 97.72%	 Validation accuracy: 98.37%
Current epoch: 1 [0 / 50000 (0.00%)]	Loss: 0.123680 	 Train accuracy: 98.44%	 Validation accuracy: 98.05%
Current epoch: 1 [6400 / 50000 (12.79%)]	Loss: 0.100362 	 Train accuracy: 98.19%	 Validation accuracy: 98.56%
Current epoch: 1 [12800 / 50000 (25.58%)]	Loss: 0.035721 	 Train accuracy: 98.30%	 Validation accuracy: 98.46%
Current epoch: 1 [19200 / 50000 (38.36%)]	Loss: 0.007506 	 Train accuracy: 98.56%	 Validation accuracy: 98.70%
Current epoch: 1 [25600 / 50000 (51.15%)]	Loss: 0.016346 	 Train accuracy: 98.41%	 Validation accuracy: 98.50%
Current epoch: 1 [32000 / 50000 (63.94%)]	Loss: 0.035862 	 Train accuracy: 98.47%	 Validation accuracy: 98.66%
Current epoch: 1 [38400 / 50000 (76.73%)]	Loss: 0.050021 	 Train accuracy: 98.56%	 Validation accuracy: 98.45%
Current epoch: 1 [44800 / 50000 (89.51%)]	Loss: 0.013187 	 Train accuracy: 98.52%	 Validation accuracy: 98.41%
Current epoch: 2 [0 / 50000 (0.00%)]	Loss: 0.020915 	 Train accuracy: 100.00%	 Validation accuracy: 98.88%
Current epoch: 2 [6400 / 50000 (12.79%)]	Loss: 0.091681 	 Train accuracy: 99.33%	 Validation accuracy: 98.56%
Current epoch: 2 [12800 / 50000 (25.58%)]	Loss: 0.034122 	 Train accuracy: 98.95%	 Validation accuracy: 98.88%
Current epoch: 2 [19200 / 50000 (38.36%)]	Loss: 0.074170 	 Train accuracy: 98.78%	 Validation accuracy: 98.63%
Current epoch: 2 [25600 / 50000 (51.15%)]	Loss: 0.018684 	 Train accuracy: 98.92%	 Validation accuracy: 98.45%
Current epoch: 2 [32000 / 50000 (63.94%)]	Loss: 0.007384 	 Train accuracy: 98.80%	 Validation accuracy: 98.89%
Current epoch: 2 [38400 / 50000 (76.73%)]	Loss: 0.015758 	 Train accuracy: 99.12%	 Validation accuracy: 98.95%
Current epoch: 2 [44800 / 50000 (89.51%)]	Loss: 0.014709 	 Train accuracy: 99.12%	 Validation accuracy: 98.70%
Current epoch: 3 [0 / 50000 (0.00%)]	Loss: 0.001915 	 Train accuracy: 100.00%	 Validation accuracy: 98.89%
Current epoch: 3 [6400 / 50000 (12.79%)]	Loss: 0.005979 	 Train accuracy: 99.30%	 Validation accuracy: 98.58%
Current epoch: 3 [12800 / 50000 (25.58%)]	Loss: 0.007681 	 Train accuracy: 98.98%	 Validation accuracy: 98.94%
Current epoch: 3 [19200 / 50000 (38.36%)]	Loss: 0.007919 	 Train accuracy: 99.19%	 Validation accuracy: 99.07%
Current epoch: 3 [25600 / 50000 (51.15%)]	Loss: 0.060704 	 Train accuracy: 99.16%	 Validation accuracy: 98.71%
Current epoch: 3 [32000 / 50000 (63.94%)]	Loss: 0.055418 	 Train accuracy: 99.14%	 Validation accuracy: 98.90%
Current epoch: 3 [38400 / 50000 (76.73%)]	Loss: 0.033441 	 Train accuracy: 99.12%	 Validation accuracy: 98.90%
Current epoch: 3 [44800 / 50000 (89.51%)]	Loss: 0.028578 	 Train accuracy: 99.03%	 Validation accuracy: 98.94%
Current epoch: 4 [0 / 50000 (0.00%)]	Loss: 0.016460 	 Train accuracy: 98.44%	 Validation accuracy: 98.89%
Current epoch: 4 [6400 / 50000 (12.79%)]	Loss: 0.031492 	 Train accuracy: 99.39%	 Validation accuracy: 99.08%
Current epoch: 4 [12800 / 50000 (25.58%)]	Loss: 0.023526 	 Train accuracy: 99.44%	 Validation accuracy: 98.76%
Current epoch: 4 [19200 / 50000 (38.36%)]	Loss: 0.004695 	 Train accuracy: 99.34%	 Validation accuracy: 98.94%
Current epoch: 4 [25600 / 50000 (51.15%)]	Loss: 0.000252 	 Train accuracy: 99.34%	 Validation accuracy: 99.01%
Current epoch: 4 [32000 / 50000 (63.94%)]	Loss: 0.020111 	 Train accuracy: 99.27%	 Validation accuracy: 99.08%
Current epoch: 4 [38400 / 50000 (76.73%)]	Loss: 0.024171 	 Train accuracy: 99.52%	 Validation accuracy: 98.94%
Current epoch: 4 [44800 / 50000 (89.51%)]	Loss: 0.014028 	 Train accuracy: 99.16%	 Validation accuracy: 99.04%
Current epoch: 5 [0 / 50000 (0.00%)]	Loss: 0.000271 	 Train accuracy: 100.00%	 Validation accuracy: 99.24%
Current epoch: 5 [6400 / 50000 (12.79%)]	Loss: 0.017214 	 Train accuracy: 99.64%	 Validation accuracy: 99.15%
Current epoch: 5 [12800 / 50000 (25.58%)]	Loss: 0.000260 	 Train accuracy: 99.53%	 Validation accuracy: 99.02%
Process finished with exit code -1
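The training accuracy on the first line of each epoch is very noisy. At that point train_rights holds only the current batch of 64 images, so it can only take values such as 7.81% (5/64), 98.44% (63/64) and 100.00% (64/64). Later lines average over the 100 batches since the previous report.

The helper below is a hypothetical pure-Python restatement of how fit() combines the (correct, total) pairs returned by accuracy(). The counts in the last example are made up.

```python
def running_accuracy(pairs):
    """Hypothetical helper: percent correct over (correct, total) pairs, as fit() does with train_rights."""
    correct = sum(c for c, _ in pairs)
    total = sum(n for _, n in pairs)
    return 100.0 * correct / total


# First report of an epoch: only the current batch of 64 images has been collected.
print(f"{running_accuracy([(5, 64)]):.2f}%")   # 7.81%
print(f"{running_accuracy([(63, 64)]):.2f}%")  # 98.44%
print(f"{running_accuracy([(64, 64)]):.2f}%")  # 100.00%

# Later reports cover the 100 batches (6400 images) since the previous report.
later = [(63, 64)] * 50 + [(64, 64)] * 50
print(f"{running_accuracy(later):.2f}%")
```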


Reposted from blog.csdn.net/sdbyp/article/details/131031803