Beginner's Guide: Handwritten Digit Recognition Using Convolutional Neural Networks

Introduction:

In this blog post, we'll take a deep dive into using Convolutional Neural Networks (CNNs) for handwritten digit recognition. This is a classic, widely studied problem in machine learning, and mastering its solution is an excellent way to get started with deep learning.

We'll start by introducing the basics of convolutional neural networks, then walk through how to build and train a simple CNN model, and finally show how to use the trained model to recognize handwritten digits. Along the way, we explain the principle and the code behind each step in detail, so that readers can fully understand it and implement it themselves.

Whether you are a newcomer to machine learning or a researcher with some background who wants a deeper understanding of CNNs, we believe you will find new knowledge and inspiration in this article. Let's start this journey of exploration together and see the power of convolutional neural networks in handwritten digit recognition!

# View the mounted dataset directory.
# Changes under this directory are reverted automatically after the environment resets.
!ls /home/aistudio/data

In [ ]

# View the personal work directory.
# All changes under this directory persist even after a reset.
# Please clean up unnecessary files promptly to avoid slow environment loading.
!ls /home/aistudio/work

In [ ]

# # If a persistent installation is required,
# # install into a persistent path, as in the following example:
# !mkdir /home/aistudio/external-libraries
# !pip install beautifulsoup4 -t /home/aistudio/external-libraries
mkdir: cannot create directory ‘/home/aistudio/external-libraries’: File exists
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting beautifulsoup4
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/57/f4/a69c20ee4f660081a7dedb1ac57f29be9378e04edfcb90c526b923d4bebc/beautifulsoup4-4.12.2-py3-none-any.whl (142 kB)
     143.0/143.0 kB 7.6 MB/s eta 0:00:00
Collecting soupsieve>1.2
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/49/37/673d6490efc51ec46d198c75903d99de59baffdd47aea3d071b80a9e4e89/soupsieve-2.4.1-py3-none-any.whl (36 kB)
Installing collected packages: soupsieve, beautifulsoup4
Successfully installed beautifulsoup4-4.12.2 soupsieve-2.4.1
WARNING: Target directory /home/aistudio/external-libraries/bs4 already exists. Specify --upgrade to force replacement.
WARNING: Target directory /home/aistudio/external-libraries/soupsieve-2.4.1.dist-info already exists. Specify --upgrade to force replacement.
WARNING: Target directory /home/aistudio/external-libraries/beautifulsoup4-4.12.2.dist-info already exists. Specify --upgrade to force replacement.
WARNING: Target directory /home/aistudio/external-libraries/soupsieve already exists. Specify --upgrade to force replacement.

[notice] A new release of pip available: 22.1.2 -> 23.1.2
[notice] To update, run: pip install --upgrade pip

In [ ]

# Also add the following code so that, every time the environment (kernel) starts,
# you only need to run it to make the persistently installed packages importable:
import sys 
sys.path.append('/home/aistudio/external-libraries')

In [ ]

import paddle
import paddle.nn as nn
import paddle.optimizer as optim
import paddle.nn.functional as F
from paddle.vision.datasets import MNIST
import paddle.vision.transforms as transforms
import paddle.vision.models as models
import matplotlib.pyplot as plt
import numpy as np
# Importing PaddlePaddle
'''
PaddlePaddle: import paddle ; import paddle.vision.transforms as transforms...
PyTorch:      import torch  ; import torchvision.transforms as transforms...
'''

In [ ]

image_size = 28
num_classes = 10
num_epochs = 20
batch_size = 64
learning_rate = 0.001

train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transforms.ToTensor())
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transforms.ToTensor())
'''
PaddlePaddle: load MNIST with paddle.vision.datasets.MNIST.
PyTorch:      load MNIST with torchvision.datasets.MNIST.
'''
# Generate random indices for the validation and "test" sets.
# Note: both are carved out of the training data, so the MNIST test
# split loaded above is overwritten below.
permutes = np.random.permutation(len(train_dataset))
indices_val = permutes[:5000]
indices_test = permutes[5000:]

# Create Subset datasets for the validation and test sets
val_dataset = paddle.io.Subset(train_dataset, indices_val)
test_dataset = paddle.io.Subset(train_dataset, indices_test)
'''
PaddlePaddle: create subset datasets with paddle.io.Subset.
PyTorch:
sampler_val = torch.utils.data.sampler.SubsetRandomSampler(indices_val)
sampler_test = torch.utils.data.sampler.SubsetRandomSampler(indices_test)
'''
# Create data loaders for train, validation, and test sets
train_loader = paddle.io.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = paddle.io.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = paddle.io.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
'''
PaddlePaddle: create loaders with paddle.io.DataLoader.
PyTorch:      create loaders with torch.utils.data.DataLoader.
'''
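The train/validation split above boils down to shuffling the 60,000 training indices and slicing them. Here is a minimal, framework-free sketch of the same idea (numpy only; the sizes are taken from the code above, the function name is our own):

```python
import numpy as np

def split_indices(n_samples, n_val, seed=0):
    """Shuffle indices and carve off the first n_val for validation."""
    rng = np.random.default_rng(seed)
    permutes = rng.permutation(n_samples)
    return permutes[:n_val], permutes[n_val:]

indices_val, indices_test = split_indices(60000, 5000)
print(len(indices_val), len(indices_test))  # 5000 55000
```

Because both slices come from one permutation, the two index sets are disjoint and together cover every sample exactly once.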

In [ ]


depth = [4, 8]  # number of feature maps in the two convolutional layers
class ConvNet(nn.Layer):
    '''
    PaddlePaddle: subclass paddle.nn.Layer.
    PyTorch:      subclass torch.nn.Module.
    '''
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2D(1, depth[0], 5, padding=2)
        self.pool = nn.MaxPool2D(2, 2)
        self.conv2 = nn.Conv2D(depth[0], depth[1], 5, padding=2)
        # After two 2x2 poolings the 28x28 image becomes 7x7 with depth[1] channels.
        self.fc1 = nn.Linear((image_size // 4) * (image_size // 4) * depth[1], 512)
        self.fc2 = nn.Linear(512, num_classes)
        '''
        PaddlePaddle: layers live in paddle.nn, e.g. paddle.nn.Conv2D.
        PyTorch:      layers live in torch.nn,  e.g. torch.nn.Conv2d.
        '''
    def forward(self, x):
        x = F.relu(self.conv1(x))
        '''
        PaddlePaddle: functional ops live in paddle.nn.functional.
        PyTorch:      functional ops live in torch.nn.functional.
        '''
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        '''
        PaddlePaddle: use paddle.flatten.
        PyTorch:      use x.view() or torch.flatten.
        '''
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)
        x = F.log_softmax(x, axis=1)
        return x
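The `in_features` of `fc1` follows directly from the geometry: each 2x2 max-pooling halves the height and width, so two poolings shrink 28x28 to 7x7, and with `depth[1] = 8` channels that flattens to 392 values. A quick sanity-check of that arithmetic (helper name is our own):

```python
def conv_out_features(image_size=28, n_pools=2, pool_size=2, channels=8):
    """Flattened feature count after repeated pooling: side^2 * channels."""
    side = image_size
    for _ in range(n_pools):
        side //= pool_size  # each pooling halves height and width
    return side * side * channels

print(conv_out_features())  # 7 * 7 * 8 = 392
```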

In [ ]

def rightness(predictions, labels):
    """Return (number of correct predictions, batch size)."""
    pred = paddle.argmax(predictions, axis=1)                  # paddle.argmax
    rights = paddle.sum(paddle.cast(pred == labels, 'int32'))  # paddle.sum
    return rights, len(labels)
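`rightness` just counts argmax matches. The same bookkeeping in plain numpy (a hedged stand-in for the paddle version, handy for unit-testing the logic without a framework):

```python
import numpy as np

def rightness_np(predictions, labels):
    """predictions: (batch, classes) scores; labels: (batch,) integer classes.
    Returns (number correct, batch size), mirroring the paddle version."""
    pred = np.argmax(predictions, axis=1)
    rights = int(np.sum(pred == labels))
    return rights, len(labels)

scores = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
print(rightness_np(scores, np.array([1, 0, 0])))  # (2, 3)
```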

In [ ]

net = ConvNet()

# Note: forward() already ends in log_softmax. Since log_softmax only shifts
# each row by a constant (its logsumexp), CrossEntropyLoss on these outputs
# equals CrossEntropyLoss on raw logits; NLLLoss would be the more direct pairing.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Momentum(learning_rate=0.001, momentum=0.9, parameters=net.parameters())
'''
PaddlePaddle: paddle.nn.CrossEntropyLoss()
PyTorch:      torch.nn.CrossEntropyLoss()

PaddlePaddle: define the optimizer with paddle.optimizer.Momentum().
PyTorch:      define the optimizer with torch.optim.SGD(..., momentum=0.9).
'''
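It may look wrong to feed log_softmax outputs into a softmax-based cross-entropy, but it is mathematically harmless: log_softmax subtracts a per-row constant, and softmax is invariant to per-row shifts. A small numpy check of that identity (our own reference implementation, not paddle's):

```python
import numpy as np

def log_softmax(x):
    x = x - x.max(axis=1, keepdims=True)  # shift for numerical stability
    return x - np.log(np.exp(x).sum(axis=1, keepdims=True))

def cross_entropy(logits, labels):
    """Mean softmax cross-entropy computed from raw scores."""
    ls = log_softmax(logits)
    return -ls[np.arange(len(labels)), labels].mean()

logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])
labels = np.array([0, 2])
# Applying cross-entropy to log_softmax outputs changes nothing:
print(np.isclose(cross_entropy(logits, labels),
                 cross_entropy(log_softmax(logits), labels)))  # True
```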

In [10]

record = []   # container for accuracy values
weights = []  # record the convolution kernels every so many steps

# Main training loop
for epoch in range(num_epochs):

    train_rights = []  # container for training-set accuracy
    for batch_idx, (data, target) in enumerate(train_loader):  # loop over batches
        data = paddle.to_tensor(data)  # PaddlePaddle: convert with paddle.to_tensor()
        target = paddle.to_tensor(target)
        target = paddle.squeeze(target, axis=1)  # match target shape to the prediction output

        net.train()  # put the model in training mode
        output = net(data)                # one forward pass
        loss = criterion(output, target)  # compute the loss
        optimizer.clear_grad()            # clear gradients
        '''
        PaddlePaddle: clear gradients with optimizer.clear_grad().
        PyTorch:      clear gradients with optimizer.zero_grad().
        '''
        loss.backward()   # backpropagation
        optimizer.step()  # one stochastic gradient descent step

        pred = paddle.argmax(output, axis=1)
        rights = paddle.sum(paddle.cast(pred == target, dtype='int32'))
        train_rights.append((rights.numpy()[0], data.shape[0]))

        if batch_idx % 100 == 0:  # every 100 batches

            # train_r is a pair: (correctly classified samples, total samples) on the training set
            train_r = (np.sum([tup[0] for tup in train_rights]), np.sum([tup[1] for tup in train_rights]))
            net.eval()  # put the model in evaluation mode for validation
            val_rights = []  # container for validation-set accuracy
            for (data, target) in val_loader:
                data = paddle.to_tensor(data)
                target = paddle.to_tensor(target)
                target = paddle.squeeze(target, axis=1)

                output = net(data)  # one forward pass
                pred = paddle.argmax(output, axis=1)
                rights = paddle.sum(paddle.cast(pred == target, dtype='int32'))
                val_rights.append((rights.numpy()[0], data.shape[0]))

            # val_r is a pair: (correctly classified samples, total samples) on the validation set
            val_r = (np.sum([tup[0] for tup in val_rights]), np.sum([tup[1] for tup in val_rights]))
            # Print progress; the training accuracy is averaged over all batches
            # seen so far in the current epoch.
            train_accuracy = 100. * train_r[0] / train_r[1]
            val_accuracy = 100. * val_r[0] / val_r[1]
            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tTraining accuracy: {:.2f}%\tValidation accuracy: {:.2f}%'.format(
                epoch, train_r[1], len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item(),
                train_accuracy,
                val_accuracy))

            # Store error rates and weights for later analysis
            record.append((100. - train_accuracy, 100. - val_accuracy))
            weights.append([param.numpy() for param in net.parameters()])
'''
PaddlePaddle: use PaddlePaddle functions such as paddle.argmax().
PyTorch:      use PyTorch functions such as torch.max().
'''
Epoch: 0 [64/60000 (0%)] Loss: 0.294312 Training accuracy: 90.62% Validation accuracy: 94.04%
Epoch: 0 [6464/60000 (11%)] Loss: 0.200471 Training accuracy: 93.04% Validation accuracy: 94.36%
Epoch: 0 [12864/60000 (21%)] Loss: 0.172873 Training accuracy: 93.14% Validation accuracy: 94.64%
Epoch: 0 [19264/60000 (32%)] Loss: 0.267028 Training accuracy: 93.29% Validation accuracy: 94.92%
Epoch: 0 [25664/60000 (43%)] Loss: 0.082636 Training accuracy: 93.30% Validation accuracy: 95.18%
Epoch: 0 [32064/60000 (53%)] Loss: 0.197915 Training accuracy: 93.43% Validation accuracy: 95.24%
Epoch: 0 [38464/60000 (64%)] Loss: 0.189998 Training accuracy: 93.49% Validation accuracy: 95.32%
Epoch: 0 [44864/60000 (75%)] Loss: 0.223470 Training accuracy: 93.55% Validation accuracy: 95.44%
Epoch: 0 [51264/60000 (85%)] Loss: 0.112054 Training accuracy: 93.60% Validation accuracy: 95.62%
Epoch: 0 [57664/60000 (96%)] Loss: 0.045446 Training accuracy: 93.71% Validation accuracy: 95.96%
... (output for epochs 1-18 omitted for brevity) ...
Epoch: 19 [64/60000 (0%)] Loss: 0.064476 Training accuracy: 95.31% Validation accuracy: 98.80%
Epoch: 19 [6464/60000 (11%)] Loss: 0.045951 Training accuracy: 98.02% Validation accuracy: 98.78%
Epoch: 19 [12864/60000 (21%)] Loss: 0.009943 Training accuracy: 98.28% Validation accuracy: 98.88%
Epoch: 19 [19264/60000 (32%)] Loss: 0.006517 Training accuracy: 98.28% Validation accuracy: 98.72%
Epoch: 19 [25664/60000 (43%)] Loss: 0.117733 Training accuracy: 98.32% Validation accuracy: 99.00%
Epoch: 19 [32064/60000 (53%)] Loss: 0.012346 Training accuracy: 98.30% Validation accuracy: 98.82%
Epoch: 19 [38464/60000 (64%)] Loss: 0.004596 Training accuracy: 98.30% Validation accuracy: 98.86%
Epoch: 19 [44864/60000 (75%)] Loss: 0.057226 Training accuracy: 98.31% Validation accuracy: 98.72%
Epoch: 19 [51264/60000 (85%)] Loss: 0.045235 Training accuracy: 98.33% Validation accuracy: 99.00%
Epoch: 19 [57664/60000 (96%)] Loss: 0.069116 Training accuracy: 98.33% Validation accuracy: 98.94%
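The running accuracies printed above come from summing the per-batch (correct, total) pairs collected in `train_rights` and `val_rights`. That reduction is easy to isolate and test on its own (plain Python; the batch numbers below are made up for illustration):

```python
def running_accuracy(batch_stats):
    """batch_stats: list of (n_correct, n_samples) pairs, one per batch.
    Returns the accumulated accuracy as a percentage."""
    correct = sum(c for c, _ in batch_stats)
    total = sum(n for _, n in batch_stats)
    return 100.0 * correct / total

stats = [(58, 64), (60, 64), (61, 64)]  # hypothetical per-batch counts
print(round(running_accuracy(stats), 2))  # 93.23
```

Summing counts before dividing weights every sample equally, which is why this differs from naively averaging per-batch accuracies when the last batch is smaller.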

In [11]

paddle.save(net.state_dict(), 'mnist_conv_checkpoint.pdparams')
'''
PaddlePaddle: save model parameters with paddle.save().
PyTorch:      save model parameters with torch.save().
'''

In [12]

# Visualize the kernels of the first convolutional layer
plt.figure(figsize=(10, 7))
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.imshow(net.conv1.weight.numpy()[i, 0, ...])

<Figure size 1000x700 with 4 Axes>
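Instead of one subplot per kernel, the filters can also be tiled into a single image before plotting. A small numpy helper for that (the shape matches conv1's weight, (out_channels, in_channels, 5, 5); the random weights here are placeholders for `net.conv1.weight.numpy()`):

```python
import numpy as np

def kernel_montage(weights):
    """Tile (n, 1, k, k) kernels side by side into one (k, n*k) image."""
    n, _, k, _ = weights.shape
    return np.concatenate([weights[i, 0] for i in range(n)], axis=1)

w = np.random.randn(4, 1, 5, 5)  # stand-in for net.conv1.weight.numpy()
print(kernel_montage(w).shape)  # (5, 20)
```

The resulting array can be passed to a single `plt.imshow` call.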

In [13]

# Plot the kernels of the second convolutional layer
plt.figure(figsize=(15, 10))
for i in range(4):
    for j in range(8):
        plt.subplot(4, 8, i * 8 + j + 1)
        plt.imshow(net.conv2.weight.numpy()[j, i, ...])

<Figure size 1500x1000 with 32 Axes>


Origin blog.csdn.net/m0_68036862/article/details/131484015