Pytorch implementation --- handwritten digit recognition

        This article is a study excerpt and note recorded by the blogger for personal study, research or appreciation when he is studying in artificial intelligence and other fields. It is based on the blogger’s understanding of artificial intelligence and other fields. If there is any inappropriateness or infringement, We will correct it immediately after pointing it out. We hope for your understanding. Article classification in Pytorch :

       Pytorch (1) --- " pytorch implementation --- handwritten digit recognition"

Pytorch implementation --- handwritten digit recognition

Table of contents

1.Project introduction

2. Implementation method

3. Program code

4. Running results


1.Project introduction

        Use pytorch to realize handwritten digit recognition. It is a very simple small project. The environment is set up and it can be completed in one run.


2. Implementation method

2.1 Method 1        

 Installation library:

pip install numpy torch torchvision matplotlib

 run:

python test.py

The MNIST data set will be downloaded during the first run, please keep the network open.

2.2 Method 2

        If you use pycharm and have installed the pytorch environment, just run the following code directly in the pytorch environment.


3. Program code

"""手写数字识别项目
    时间:2023.11.6
    环境:pytorch
    作者:Rainbook
"""

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt

class Net(torch.nn.Module):  # 定义一个Net类,神经网络的主体
    def __init__(self):  # 全连接层,四个
        super().__init__()
        self.fc1 = torch.nn.Linear(28*28, 64)  # 输入层输入28*28,输出64
        self.fc2 = torch.nn.Linear(64, 64)  # 中间层,输入64,输出64
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)  # 中间层(隐藏层)的最后一层,输出10个特征值
    
    def forward(self, x):  # 前向传播过程
        # self.fc1(x)全连接线性计算,再套上一个激活函数torch.nn.functional.relu()
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        # 最后一层进行softmax归一化,log_softmax是为了提高计算稳定性,在softmax后面套上了一个对数运算
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
        return x


def get_data_loader(is_train):
    to_tensor = transforms.Compose([transforms.ToTensor()])  # 定义数据转换类型tensor,多维数组(张量)
    """下载MNIST数据集,
        "":当前位置
        is_train:判断是训练集还是测试集;
        batch_size:一个批次包含15张图片;
        shuffle:数据随机打乱的
    """
    data_set = MNIST("", is_train, transform=to_tensor, download=True)
    return DataLoader(data_set, batch_size=15, shuffle=True)  # 数据加载器


def evaluate(test_data, net):  # 用来评估神经网络
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for (x, y) in test_data:
            outputs = net.forward(x.view(-1, 28*28))  # 计算神经网络的预测值
            for i, output in enumerate(outputs):  # 对每个批次的预测值进行比较,累加正确预测的数量
                if torch.argmax(output) == y[i]:
                    n_correct += 1
                n_total += 1
    return n_correct / n_total  # 返回正确率


def main():
    # 导入训练集和测试集
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    net = Net()  # 初始化神经网络

    # 打印初始网络的正确率,应当是10%附近。手写数字有十种结果,随机猜的正确率就是1/10
    print("initial accuracy:", evaluate(test_data, net))
    """训练神经网络
    pytorch的固定写法
    """
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(5):  # 需要在一个数据集上反复训练神经网络,epoch网络轮次,提高数据集的利用率
        for (x, y) in train_data:
            net.zero_grad()  # 初始化
            output = net.forward(x.view(-1, 28*28))  # 正向传播
            # 计算差值,nll_loss对数损失函数,为了匹配log_softmax的log运算
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()  # 反向误差传播
            optimizer.step()  # 优化网络参数
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))  # 打印当前网络的正确率

    """测试神经网络
        训练完成后,随机抽取3张图片进行测试
    """
    for (n, (x, _)) in enumerate(test_data):
        if n > 3:
            break
        predict = torch.argmax(net.forward(x[0].view(-1, 28*28)))  # 测试结果
        plt.figure(n)  # 画出图像
        plt.imshow(x[0].view(28, 28))  # 像素大小28*28
        plt.title("prediction: " + str(int(predict)))  # figure的标题
    plt.show()


if __name__ == "__main__":
    main()

4. Running results

4.1 Accuracy rate

4.2 Test results

        Reference source: Station B

        If there is any inappropriateness or inaccuracy in the article, I hope you will understand and point it out. Since some texts, pictures, etc. come from the Internet, the true source cannot be verified. If there is any related dispute, please contact the blogger to delete it. If there are any errors, questions or infringements, please leave a comment to contact the author, or follow the VX public account: Rain21321 to contact the author.

Guess you like

Origin blog.csdn.net/qq_51399582/article/details/134246314