初阶学习:用卷积神经网络实现手写数字识别

简介:

在这篇博客文章中,我们将深入探讨如何使用卷积神经网络(Convolutional Neural Networks,简称CNN)来实现手写数字识别。这是机器学习中一个非常基础且广泛应用的问题,理解和掌握其解决方法,对于入门深度学习非常有帮助。

我们将从介绍卷积神经网络的基本原理开始,然后详细讲解如何构建和训练一个简单的CNN模型,最后我们会展示如何使用这个模型来识别手写数字。在这个过程中,我们将详细解释每一步的原理和代码,使得读者能够完全理解并且能够自己动手实现。

无论你是一位刚刚接触机器学习的新手,还是一位有一定基础并希望进一步深入理解CNN的研究者,我们都相信你能在这篇文章中收获到新的知识和启发。让我们一起开始这次探索之旅,感受卷积神经网络在手写数字识别上的强大之处!

# 查看当前挂载的数据集目录, 该目录下的变更重启环境后会自动还原
# View dataset directory. 
# This directory will be recovered automatically after resetting environment. 
!ls /home/aistudio/data

In [ ]

# 查看工作区文件, 该目录下的变更将会持久保存. 请及时清理不必要的文件, 避免加载过慢.
# View personal work directory. 
# All changes under this directory will be kept even after reset. 
# Please clean unnecessary files in time to speed up environment loading. 
!ls /home/aistudio/work

In [ ]

# # 如果需要进行持久化安装, 需要使用持久化路径, 如下方代码示例:
# # If a persistence installation is required, 
# # you need to use the persistence path as the following: 
# !mkdir /home/aistudio/external-libraries
# !pip install beautifulsoup4 -t /home/aistudio/external-libraries
mkdir: cannot create directory ‘/home/aistudio/external-libraries’: File exists
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting beautifulsoup4
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/57/f4/a69c20ee4f660081a7dedb1ac57f29be9378e04edfcb90c526b923d4bebc/beautifulsoup4-4.12.2-py3-none-any.whl (142 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 143.0/143.0 kB 7.6 MB/s eta 0:00:00
Collecting soupsieve>1.2
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/49/37/673d6490efc51ec46d198c75903d99de59baffdd47aea3d071b80a9e4e89/soupsieve-2.4.1-py3-none-any.whl (36 kB)
Installing collected packages: soupsieve, beautifulsoup4
Successfully installed beautifulsoup4-4.12.2 soupsieve-2.4.1
WARNING: Target directory /home/aistudio/external-libraries/bs4 already exists. Specify --upgrade to force replacement.
WARNING: Target directory /home/aistudio/external-libraries/soupsieve-2.4.1.dist-info already exists. Specify --upgrade to force replacement.
WARNING: Target directory /home/aistudio/external-libraries/beautifulsoup4-4.12.2.dist-info already exists. Specify --upgrade to force replacement.
WARNING: Target directory /home/aistudio/external-libraries/soupsieve already exists. Specify --upgrade to force replacement.

[notice] A new release of pip available: 22.1.2 -> 23.1.2
[notice] To update, run: pip install --upgrade pip

In [ ]

# 同时添加如下代码, 这样每次环境(kernel)启动的时候只要运行下方代码即可: 
# Also add the following code, 
# so that every time the environment (kernel) starts, 
# just run the following code: 
import sys 
sys.path.append('/home/aistudio/external-libraries')

In [ ]

import paddle
import paddle.nn as nn
import paddle.optimizer as optim
import paddle.nn.functional as F
from paddle.vision.datasets import MNIST
import paddle.vision.transforms as transforms
import paddle.vision.models as models
import matplotlib.pyplot as plt
import numpy as np
# paddlepaddle的导入
'''
PaddlePaddle: import paddle ; import paddle.vision.transforms as transforms...
PyTorch: import torch ; import torchvision.transforms as transforms...
'''
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import MutableMapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Iterable, Mapping
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  from collections import Sized

In [ ]

image_size = 28
num_classes = 10
num_epochs = 20
batch_size = 64
learning_rate = 0.001

train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transforms.ToTensor())
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transforms.ToTensor())
'''
PaddlePaddle:使用paddle.vision.datasets.MNIST加载MNIST数据集。
PyTorch:使用torchvision.datasets.MNIST加载MNIST数据集。
'''
# Generate random indices for validation and test sets
permutes = np.random.permutation(range(len(train_dataset)))
indices_val = permutes[:5000]
indices_test = permutes[5000:]

# Create Subset datasets for validation and test sets
val_dataset = paddle.io.Subset(train_dataset, indices_val)
test_dataset = paddle.io.Subset(train_dataset, indices_test)
'''
PaddlePaddle:使用paddle.io.Subset创建子集数据集。
PyTorch:
sampler_val = torch.utils.data.sampler.SubsetRandomSampler(indices_val)
sampler_test = torch.utils.data.sampler.SubsetRandomSampler(indices_test)
'''
# Create data loaders for train, validation, and test sets
train_loader = paddle.io.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = paddle.io.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = paddle.io.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
'''
PaddlePaddle:使用paddle.io.DataLoader创建
PyTorch:使用torch.utils.data.DataLoader创建
'''

In [ ]


depth = [4, 8]
class ConvNet(nn.Layer):
    '''
    PaddlePaddle:使用paddle.nn.Layer作为基类
    PyTorch:使用torch.nn.Module作为基类
    '''
    def __init__(self):
        super(ConvNet, self).__init__()
        self.conv1 = nn.Conv2D(1, 4, 5, padding=2)
        self.pool = nn.MaxPool2D(2, 2)
        self.conv2 = nn.Conv2D(depth[0], depth[1], 5, padding=2)
        self.fc1 = nn.Linear(image_size // 4 * image_size // 4 * 8, 512)
        self.fc2 = nn.Linear(512, num_classes)
        '''
        PaddlePaddle:使用paddle.nn模块下的层定义,如paddle.nn.Conv2D
        PyTorch:使用torch.nn模块下的层定义,如torch.nn.Conv2d
        '''
    def forward(self, x):
        x = F.relu(self.conv1(x))
        '''
        PaddlePaddle:使用paddle.nn.functional模块下的函数
        PyTorch:使用torch.nn.functional模块下的函数
        '''
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = paddle.flatten(x, start_axis=1, stop_axis=-1)
        '''
        PaddlePaddle:使用paddle.flatten函数
        PyTorch:使用torch.view函数
        '''
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.fc2(x)
        x = F.log_softmax(x, axis=1)
        return x

In [ ]

def rightness(predictions, labels):
    pred = paddle.argmax(predictions, axis=1)  #paddle.argmx
    rights = paddle.sum(pred == labels)  #paddle.sum
    return rights, len(labels) 

In [ ]

net = ConvNet()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Momentum(learning_rate=0.001, momentum=0.9, parameters=net.parameters())
'''
PaddlePaddle:使用paddle.nn.CrossEntropyLoss()
PyTorch:使用torch.nn.CrossEntropyLoss()

PaddlePaddle:使用paddle.optimizer.Momentum()定义优化器。
PyTorch:使用torch.optim.SGD定义优化器
'''

In [10]

record = []  # 记录准确率等数值的容器
weights = []  # 每若干步就记录一次卷积核

# 开始训练循环
for epoch in range(num_epochs):

    train_rights = []  # 记录训练数据集准确率的容器
    for batch_idx, (data, target) in enumerate(train_loader):  # 针对容器中的每一个批进行循环
        data = paddle.to_tensor(data)# PaddlePaddle:使用paddle.to_tensor()转换为PaddlePaddle张量。
        target = paddle.to_tensor(target)
        target = paddle.squeeze(target, axis=1)  # 调整目标维度与预测输出相匹配
        data.stop_gradient = False
        target.stop_gradient = False

        net.train()  # 给网络模型做标记,表示模型在训练集上训练
        output = net(data)  # 完成一次预测
        loss = criterion(output, target)  # 计算误差
        optimizer.clear_grad()  # 清空梯度
        '''
        PaddlePaddle:使用optimizer.clear_grad()清空梯度。
        PyTorch:使用optimizer.zero_grad()清空梯度。
        '''
        loss.backward()  # 反向传播
        optimizer.step()  # 一步随机梯度下降

        pred = paddle.argmax(output, axis=1)
        rights = paddle.sum(paddle.cast(pred == target, dtype='int32'))
        train_rights.append((rights.numpy()[0], data.shape[0]))

        if batch_idx % 100 == 0:  # 每间隔100个batch执行一次

            # train_r为一个二元组,分别记录训练集中分类正确的数量和该集合中总的样本数
            train_r = (np.sum([tup[0] for tup in train_rights]), np.sum([tup[1] for tup in train_rights]))
            # print(train_r)
            net.eval()  # 给网络模型做标记,表示模型在训练集上评估
            val_rights = []  # 记录校验数据集准确率的容器
            for (data, target) in val_loader:
                data = paddle.to_tensor(data)
                target = paddle.to_tensor(target)
                target = paddle.squeeze(target, axis=1)

                output = net(data)  # 完成一次预测
                pred = paddle.argmax(output, axis=1)
                rights = paddle.sum(paddle.cast(pred == target, dtype='int32'))
                # print(rights)
                val_rights.append((rights.numpy()[0], data.shape[0]))

            # val_r为一个二元组,分别记录校验集中分类正确的数量和该集合中总的样本数
            val_r = (np.sum([tup[0] for tup in val_rights]), np.sum([tup[1] for tup in val_rights]))
            # print(val_r)
            # 打印准确率等数值,其中正确率为本训练周期Epoch开始后到目前批次的正确率的平均值
            train_accuracy = 100. * train_r[0] / train_r[1]
            val_accuracy = 100. * val_r[0] / val_r[1]
            print('训练周期: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t训练正确率: {:.2f}%\t校验正确率: {:.2f}%'.format(
                epoch, train_r[1], len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item(),
                train_accuracy,
                val_accuracy))

            # 将准确率和权重等数值加载到容器中,以方便后续处理
            record.append((100. - train_accuracy, 100. - val_accuracy))
            weights.append([param.numpy() for param in net.parameters()])
'''
PaddlePaddle:使用PaddlePaddle的函数,如paddle.argmax()等
PyTorch:使用PyTorch的函数,如torch.max()等
'''
训练周期: 0 [64/60000 (0%)]	Loss: 0.294312	训练正确率: 90.62%	校验正确率: 94.04%
训练周期: 0 [6464/60000 (11%)]	Loss: 0.200471	训练正确率: 93.04%	校验正确率: 94.36%
训练周期: 0 [12864/60000 (21%)]	Loss: 0.172873	训练正确率: 93.14%	校验正确率: 94.64%
训练周期: 0 [19264/60000 (32%)]	Loss: 0.267028	训练正确率: 93.29%	校验正确率: 94.92%
训练周期: 0 [25664/60000 (43%)]	Loss: 0.082636	训练正确率: 93.30%	校验正确率: 95.18%
训练周期: 0 [32064/60000 (53%)]	Loss: 0.197915	训练正确率: 93.43%	校验正确率: 95.24%
训练周期: 0 [38464/60000 (64%)]	Loss: 0.189998	训练正确率: 93.49%	校验正确率: 95.32%
训练周期: 0 [44864/60000 (75%)]	Loss: 0.223470	训练正确率: 93.55%	校验正确率: 95.44%
训练周期: 0 [51264/60000 (85%)]	Loss: 0.112054	训练正确率: 93.60%	校验正确率: 95.62%
训练周期: 0 [57664/60000 (96%)]	Loss: 0.045446	训练正确率: 93.71%	校验正确率: 95.96%
训练周期: 1 [64/60000 (0%)]	Loss: 0.162490	训练正确率: 96.88%	校验正确率: 95.80%
训练周期: 1 [6464/60000 (11%)]	Loss: 0.154191	训练正确率: 94.46%	校验正确率: 95.98%
训练周期: 1 [12864/60000 (21%)]	Loss: 0.163062	训练正确率: 94.61%	校验正确率: 96.10%
训练周期: 1 [19264/60000 (32%)]	Loss: 0.098654	训练正确率: 94.65%	校验正确率: 96.14%
训练周期: 1 [25664/60000 (43%)]	Loss: 0.142942	训练正确率: 94.71%	校验正确率: 96.04%
训练周期: 1 [32064/60000 (53%)]	Loss: 0.058082	训练正确率: 94.79%	校验正确率: 96.24%
训练周期: 1 [38464/60000 (64%)]	Loss: 0.056331	训练正确率: 94.83%	校验正确率: 96.42%
训练周期: 1 [44864/60000 (75%)]	Loss: 0.296984	训练正确率: 94.88%	校验正确率: 96.38%
训练周期: 1 [51264/60000 (85%)]	Loss: 0.181992	训练正确率: 94.89%	校验正确率: 96.34%
训练周期: 1 [57664/60000 (96%)]	Loss: 0.269135	训练正确率: 94.90%	校验正确率: 96.48%
训练周期: 2 [64/60000 (0%)]	Loss: 0.245220	训练正确率: 93.75%	校验正确率: 96.32%
训练周期: 2 [6464/60000 (11%)]	Loss: 0.195959	训练正确率: 95.41%	校验正确率: 96.46%
训练周期: 2 [12864/60000 (21%)]	Loss: 0.076648	训练正确率: 95.40%	校验正确率: 96.56%
训练周期: 2 [19264/60000 (32%)]	Loss: 0.103266	训练正确率: 95.38%	校验正确率: 96.56%
训练周期: 2 [25664/60000 (43%)]	Loss: 0.326429	训练正确率: 95.34%	校验正确率: 96.68%
训练周期: 2 [32064/60000 (53%)]	Loss: 0.164600	训练正确率: 95.46%	校验正确率: 96.64%
训练周期: 2 [38464/60000 (64%)]	Loss: 0.100820	训练正确率: 95.44%	校验正确率: 96.80%
训练周期: 2 [44864/60000 (75%)]	Loss: 0.082001	训练正确率: 95.48%	校验正确率: 96.82%
训练周期: 2 [51264/60000 (85%)]	Loss: 0.119786	训练正确率: 95.52%	校验正确率: 96.82%
训练周期: 2 [57664/60000 (96%)]	Loss: 0.237487	训练正确率: 95.53%	校验正确率: 96.94%
训练周期: 3 [64/60000 (0%)]	Loss: 0.058802	训练正确率: 100.00%	校验正确率: 96.78%
训练周期: 3 [6464/60000 (11%)]	Loss: 0.155010	训练正确率: 96.09%	校验正确率: 97.06%
训练周期: 3 [12864/60000 (21%)]	Loss: 0.042440	训练正确率: 95.99%	校验正确率: 96.92%
训练周期: 3 [19264/60000 (32%)]	Loss: 0.118062	训练正确率: 95.98%	校验正确率: 97.02%
训练周期: 3 [25664/60000 (43%)]	Loss: 0.100673	训练正确率: 96.02%	校验正确率: 97.08%
训练周期: 3 [32064/60000 (53%)]	Loss: 0.172493	训练正确率: 95.99%	校验正确率: 97.00%
训练周期: 3 [38464/60000 (64%)]	Loss: 0.187268	训练正确率: 95.98%	校验正确率: 97.26%
训练周期: 3 [44864/60000 (75%)]	Loss: 0.061519	训练正确率: 96.00%	校验正确率: 97.30%
训练周期: 3 [51264/60000 (85%)]	Loss: 0.073172	训练正确率: 96.07%	校验正确率: 96.98%
训练周期: 3 [57664/60000 (96%)]	Loss: 0.055029	训练正确率: 96.06%	校验正确率: 97.08%
训练周期: 4 [64/60000 (0%)]	Loss: 0.065538	训练正确率: 98.44%	校验正确率: 97.18%
训练周期: 4 [6464/60000 (11%)]	Loss: 0.039354	训练正确率: 96.49%	校验正确率: 97.44%
训练周期: 4 [12864/60000 (21%)]	Loss: 0.113937	训练正确率: 96.46%	校验正确率: 97.06%
训练周期: 4 [19264/60000 (32%)]	Loss: 0.088527	训练正确率: 96.46%	校验正确率: 97.26%
训练周期: 4 [25664/60000 (43%)]	Loss: 0.074021	训练正确率: 96.52%	校验正确率: 97.40%
训练周期: 4 [32064/60000 (53%)]	Loss: 0.161240	训练正确率: 96.51%	校验正确率: 97.46%
训练周期: 4 [38464/60000 (64%)]	Loss: 0.067512	训练正确率: 96.51%	校验正确率: 97.48%
训练周期: 4 [44864/60000 (75%)]	Loss: 0.118248	训练正确率: 96.44%	校验正确率: 97.46%
训练周期: 4 [51264/60000 (85%)]	Loss: 0.059335	训练正确率: 96.44%	校验正确率: 97.54%
训练周期: 4 [57664/60000 (96%)]	Loss: 0.059810	训练正确率: 96.48%	校验正确率: 97.48%
训练周期: 5 [64/60000 (0%)]	Loss: 0.116151	训练正确率: 96.88%	校验正确率: 97.38%
训练周期: 5 [6464/60000 (11%)]	Loss: 0.160240	训练正确率: 96.97%	校验正确率: 97.50%
训练周期: 5 [12864/60000 (21%)]	Loss: 0.272595	训练正确率: 96.76%	校验正确率: 97.48%
训练周期: 5 [19264/60000 (32%)]	Loss: 0.092262	训练正确率: 96.65%	校验正确率: 97.62%
训练周期: 5 [25664/60000 (43%)]	Loss: 0.061020	训练正确率: 96.64%	校验正确率: 97.62%
训练周期: 5 [32064/60000 (53%)]	Loss: 0.140946	训练正确率: 96.65%	校验正确率: 97.62%
训练周期: 5 [38464/60000 (64%)]	Loss: 0.172838	训练正确率: 96.69%	校验正确率: 97.60%
训练周期: 5 [44864/60000 (75%)]	Loss: 0.141658	训练正确率: 96.69%	校验正确率: 97.70%
训练周期: 5 [51264/60000 (85%)]	Loss: 0.102514	训练正确率: 96.72%	校验正确率: 97.80%
训练周期: 5 [57664/60000 (96%)]	Loss: 0.206360	训练正确率: 96.74%	校验正确率: 97.68%
训练周期: 6 [64/60000 (0%)]	Loss: 0.174180	训练正确率: 95.31%	校验正确率: 97.74%
训练周期: 6 [6464/60000 (11%)]	Loss: 0.192811	训练正确率: 96.41%	校验正确率: 97.80%
训练周期: 6 [12864/60000 (21%)]	Loss: 0.079086	训练正确率: 96.49%	校验正确率: 97.76%
训练周期: 6 [19264/60000 (32%)]	Loss: 0.094231	训练正确率: 96.68%	校验正确率: 97.68%
训练周期: 6 [25664/60000 (43%)]	Loss: 0.120969	训练正确率: 96.82%	校验正确率: 97.72%
训练周期: 6 [32064/60000 (53%)]	Loss: 0.039326	训练正确率: 96.90%	校验正确率: 97.78%
训练周期: 6 [38464/60000 (64%)]	Loss: 0.054882	训练正确率: 97.00%	校验正确率: 97.70%
训练周期: 6 [44864/60000 (75%)]	Loss: 0.161812	训练正确率: 97.01%	校验正确率: 97.76%
训练周期: 6 [51264/60000 (85%)]	Loss: 0.109295	训练正确率: 96.99%	校验正确率: 97.74%
训练周期: 6 [57664/60000 (96%)]	Loss: 0.092668	训练正确率: 97.02%	校验正确率: 97.68%
训练周期: 7 [64/60000 (0%)]	Loss: 0.072402	训练正确率: 98.44%	校验正确率: 97.88%
训练周期: 7 [6464/60000 (11%)]	Loss: 0.034500	训练正确率: 97.28%	校验正确率: 97.90%
训练周期: 7 [12864/60000 (21%)]	Loss: 0.032573	训练正确率: 97.16%	校验正确率: 97.88%
训练周期: 7 [19264/60000 (32%)]	Loss: 0.123999	训练正确率: 97.13%	校验正确率: 97.84%
训练周期: 7 [25664/60000 (43%)]	Loss: 0.116971	训练正确率: 97.16%	校验正确率: 97.82%
训练周期: 7 [32064/60000 (53%)]	Loss: 0.034661	训练正确率: 97.20%	校验正确率: 97.88%
训练周期: 7 [38464/60000 (64%)]	Loss: 0.060708	训练正确率: 97.20%	校验正确率: 97.94%
训练周期: 7 [44864/60000 (75%)]	Loss: 0.125912	训练正确率: 97.23%	校验正确率: 97.96%
训练周期: 7 [51264/60000 (85%)]	Loss: 0.186194	训练正确率: 97.22%	校验正确率: 97.98%
训练周期: 7 [57664/60000 (96%)]	Loss: 0.087361	训练正确率: 97.23%	校验正确率: 98.08%
训练周期: 8 [64/60000 (0%)]	Loss: 0.009403	训练正确率: 100.00%	校验正确率: 97.86%
训练周期: 8 [6464/60000 (11%)]	Loss: 0.326886	训练正确率: 97.42%	校验正确率: 98.00%
训练周期: 8 [12864/60000 (21%)]	Loss: 0.060366	训练正确率: 97.32%	校验正确率: 98.02%
训练周期: 8 [19264/60000 (32%)]	Loss: 0.161973	训练正确率: 97.31%	校验正确率: 98.10%
训练周期: 8 [25664/60000 (43%)]	Loss: 0.071160	训练正确率: 97.35%	校验正确率: 98.04%
训练周期: 8 [32064/60000 (53%)]	Loss: 0.017569	训练正确率: 97.37%	校验正确率: 98.10%
训练周期: 8 [38464/60000 (64%)]	Loss: 0.050663	训练正确率: 97.38%	校验正确率: 98.20%
训练周期: 8 [44864/60000 (75%)]	Loss: 0.222503	训练正确率: 97.37%	校验正确率: 98.18%
训练周期: 8 [51264/60000 (85%)]	Loss: 0.109463	训练正确率: 97.36%	校验正确率: 98.14%
训练周期: 8 [57664/60000 (96%)]	Loss: 0.080445	训练正确率: 97.35%	校验正确率: 98.10%
训练周期: 9 [64/60000 (0%)]	Loss: 0.017284	训练正确率: 100.00%	校验正确率: 98.06%
训练周期: 9 [6464/60000 (11%)]	Loss: 0.045179	训练正确率: 97.56%	校验正确率: 98.20%
训练周期: 9 [12864/60000 (21%)]	Loss: 0.031334	训练正确率: 97.56%	校验正确率: 98.04%
训练周期: 9 [19264/60000 (32%)]	Loss: 0.038351	训练正确率: 97.60%	校验正确率: 98.14%
训练周期: 9 [25664/60000 (43%)]	Loss: 0.121472	训练正确率: 97.59%	校验正确率: 98.30%
训练周期: 9 [32064/60000 (53%)]	Loss: 0.085770	训练正确率: 97.51%	校验正确率: 98.28%
训练周期: 9 [38464/60000 (64%)]	Loss: 0.059464	训练正确率: 97.48%	校验正确率: 98.32%
训练周期: 9 [44864/60000 (75%)]	Loss: 0.095663	训练正确率: 97.57%	校验正确率: 98.08%
训练周期: 9 [51264/60000 (85%)]	Loss: 0.113621	训练正确率: 97.54%	校验正确率: 98.34%
训练周期: 9 [57664/60000 (96%)]	Loss: 0.268711	训练正确率: 97.52%	校验正确率: 98.30%
训练周期: 10 [64/60000 (0%)]	Loss: 0.084389	训练正确率: 96.88%	校验正确率: 98.22%
训练周期: 10 [6464/60000 (11%)]	Loss: 0.055922	训练正确率: 97.62%	校验正确率: 98.34%
训练周期: 10 [12864/60000 (21%)]	Loss: 0.125814	训练正确率: 97.71%	校验正确率: 98.24%
训练周期: 10 [19264/60000 (32%)]	Loss: 0.045647	训练正确率: 97.67%	校验正确率: 98.18%
训练周期: 10 [25664/60000 (43%)]	Loss: 0.054552	训练正确率: 97.70%	校验正确率: 98.34%
训练周期: 10 [32064/60000 (53%)]	Loss: 0.017988	训练正确率: 97.65%	校验正确率: 98.42%
训练周期: 10 [38464/60000 (64%)]	Loss: 0.034680	训练正确率: 97.68%	校验正确率: 98.34%
训练周期: 10 [44864/60000 (75%)]	Loss: 0.068128	训练正确率: 97.68%	校验正确率: 98.24%
训练周期: 10 [51264/60000 (85%)]	Loss: 0.048264	训练正确率: 97.67%	校验正确率: 98.38%
训练周期: 10 [57664/60000 (96%)]	Loss: 0.040969	训练正确率: 97.68%	校验正确率: 98.32%
训练周期: 11 [64/60000 (0%)]	Loss: 0.095555	训练正确率: 98.44%	校验正确率: 98.28%
训练周期: 11 [6464/60000 (11%)]	Loss: 0.080019	训练正确率: 97.66%	校验正确率: 98.26%
训练周期: 11 [12864/60000 (21%)]	Loss: 0.099818	训练正确率: 97.64%	校验正确率: 98.36%
训练周期: 11 [19264/60000 (32%)]	Loss: 0.057005	训练正确率: 97.72%	校验正确率: 98.26%
训练周期: 11 [25664/60000 (43%)]	Loss: 0.032954	训练正确率: 97.76%	校验正确率: 98.54%
训练周期: 11 [32064/60000 (53%)]	Loss: 0.053869	训练正确率: 97.73%	校验正确率: 98.36%
训练周期: 11 [38464/60000 (64%)]	Loss: 0.062431	训练正确率: 97.70%	校验正确率: 98.30%
训练周期: 11 [44864/60000 (75%)]	Loss: 0.072020	训练正确率: 97.75%	校验正确率: 98.40%
训练周期: 11 [51264/60000 (85%)]	Loss: 0.010781	训练正确率: 97.75%	校验正确率: 98.48%
训练周期: 11 [57664/60000 (96%)]	Loss: 0.017668	训练正确率: 97.76%	校验正确率: 98.58%
训练周期: 12 [64/60000 (0%)]	Loss: 0.012428	训练正确率: 100.00%	校验正确率: 98.38%
训练周期: 12 [6464/60000 (11%)]	Loss: 0.055729	训练正确率: 98.21%	校验正确率: 98.46%
训练周期: 12 [12864/60000 (21%)]	Loss: 0.097194	训练正确率: 97.96%	校验正确率: 98.48%
训练周期: 12 [19264/60000 (32%)]	Loss: 0.071802	训练正确率: 97.87%	校验正确率: 98.32%
训练周期: 12 [25664/60000 (43%)]	Loss: 0.078902	训练正确率: 97.91%	校验正确率: 98.38%
训练周期: 12 [32064/60000 (53%)]	Loss: 0.247189	训练正确率: 97.88%	校验正确率: 98.54%
训练周期: 12 [38464/60000 (64%)]	Loss: 0.037298	训练正确率: 97.85%	校验正确率: 98.56%
训练周期: 12 [44864/60000 (75%)]	Loss: 0.026617	训练正确率: 97.88%	校验正确率: 98.48%
训练周期: 12 [51264/60000 (85%)]	Loss: 0.024428	训练正确率: 97.86%	校验正确率: 98.46%
训练周期: 12 [57664/60000 (96%)]	Loss: 0.017038	训练正确率: 97.86%	校验正确率: 98.52%
训练周期: 13 [64/60000 (0%)]	Loss: 0.067459	训练正确率: 98.44%	校验正确率: 98.60%
训练周期: 13 [6464/60000 (11%)]	Loss: 0.032820	训练正确率: 97.88%	校验正确率: 98.56%
训练周期: 13 [12864/60000 (21%)]	Loss: 0.009003	训练正确率: 98.03%	校验正确率: 98.60%
训练周期: 13 [19264/60000 (32%)]	Loss: 0.078738	训练正确率: 98.01%	校验正确率: 98.44%
训练周期: 13 [25664/60000 (43%)]	Loss: 0.182921	训练正确率: 97.94%	校验正确率: 98.46%
训练周期: 13 [32064/60000 (53%)]	Loss: 0.118050	训练正确率: 97.90%	校验正确率: 98.58%
训练周期: 13 [38464/60000 (64%)]	Loss: 0.024995	训练正确率: 97.90%	校验正确率: 98.52%
训练周期: 13 [44864/60000 (75%)]	Loss: 0.065351	训练正确率: 97.91%	校验正确率: 98.50%
训练周期: 13 [51264/60000 (85%)]	Loss: 0.037466	训练正确率: 97.94%	校验正确率: 98.50%
训练周期: 13 [57664/60000 (96%)]	Loss: 0.086423	训练正确率: 97.91%	校验正确率: 98.54%
训练周期: 14 [64/60000 (0%)]	Loss: 0.024936	训练正确率: 100.00%	校验正确率: 98.68%
训练周期: 14 [6464/60000 (11%)]	Loss: 0.035491	训练正确率: 98.24%	校验正确率: 98.44%
训练周期: 14 [12864/60000 (21%)]	Loss: 0.097156	训练正确率: 98.27%	校验正确率: 98.56%
训练周期: 14 [19264/60000 (32%)]	Loss: 0.163661	训练正确率: 98.22%	校验正确率: 98.60%
训练周期: 14 [25664/60000 (43%)]	Loss: 0.070727	训练正确率: 98.09%	校验正确率: 98.44%
训练周期: 14 [32064/60000 (53%)]	Loss: 0.006786	训练正确率: 98.10%	校验正确率: 98.42%
训练周期: 14 [38464/60000 (64%)]	Loss: 0.041515	训练正确率: 98.05%	校验正确率: 98.54%
训练周期: 14 [44864/60000 (75%)]	Loss: 0.009068	训练正确率: 98.05%	校验正确率: 98.40%
训练周期: 14 [51264/60000 (85%)]	Loss: 0.017752	训练正确率: 98.06%	校验正确率: 98.66%
训练周期: 14 [57664/60000 (96%)]	Loss: 0.070509	训练正确率: 98.03%	校验正确率: 98.58%
训练周期: 15 [64/60000 (0%)]	Loss: 0.026125	训练正确率: 98.44%	校验正确率: 98.58%
训练周期: 15 [6464/60000 (11%)]	Loss: 0.016736	训练正确率: 98.04%	校验正确率: 98.40%
训练周期: 15 [12864/60000 (21%)]	Loss: 0.035270	训练正确率: 97.87%	校验正确率: 98.48%
训练周期: 15 [19264/60000 (32%)]	Loss: 0.015142	训练正确率: 98.00%	校验正确率: 98.76%
训练周期: 15 [25664/60000 (43%)]	Loss: 0.131374	训练正确率: 98.00%	校验正确率: 98.56%
训练周期: 15 [32064/60000 (53%)]	Loss: 0.114820	训练正确率: 97.99%	校验正确率: 98.52%
训练周期: 15 [38464/60000 (64%)]	Loss: 0.066392	训练正确率: 97.97%	校验正确率: 98.56%
训练周期: 15 [44864/60000 (75%)]	Loss: 0.007783	训练正确率: 98.03%	校验正确率: 98.62%
训练周期: 15 [51264/60000 (85%)]	Loss: 0.025519	训练正确率: 98.02%	校验正确率: 98.74%
训练周期: 15 [57664/60000 (96%)]	Loss: 0.029439	训练正确率: 98.02%	校验正确率: 98.66%
训练周期: 16 [64/60000 (0%)]	Loss: 0.101036	训练正确率: 98.44%	校验正确率: 98.58%
训练周期: 16 [6464/60000 (11%)]	Loss: 0.091292	训练正确率: 98.04%	校验正确率: 98.52%
训练周期: 16 [12864/60000 (21%)]	Loss: 0.057263	训练正确率: 97.97%	校验正确率: 98.74%
训练周期: 16 [19264/60000 (32%)]	Loss: 0.012148	训练正确率: 98.12%	校验正确率: 98.72%
训练周期: 16 [25664/60000 (43%)]	Loss: 0.077445	训练正确率: 98.17%	校验正确率: 98.62%
训练周期: 16 [32064/60000 (53%)]	Loss: 0.026971	训练正确率: 98.18%	校验正确率: 98.72%
训练周期: 16 [38464/60000 (64%)]	Loss: 0.014428	训练正确率: 98.19%	校验正确率: 98.76%
训练周期: 16 [44864/60000 (75%)]	Loss: 0.048128	训练正确率: 98.17%	校验正确率: 98.72%
训练周期: 16 [51264/60000 (85%)]	Loss: 0.020594	训练正确率: 98.18%	校验正确率: 98.76%
训练周期: 16 [57664/60000 (96%)]	Loss: 0.048687	训练正确率: 98.14%	校验正确率: 98.66%
训练周期: 17 [64/60000 (0%)]	Loss: 0.116186	训练正确率: 98.44%	校验正确率: 98.68%
训练周期: 17 [6464/60000 (11%)]	Loss: 0.045756	训练正确率: 98.44%	校验正确率: 98.70%
训练周期: 17 [12864/60000 (21%)]	Loss: 0.093706	训练正确率: 98.31%	校验正确率: 98.82%
训练周期: 17 [19264/60000 (32%)]	Loss: 0.153490	训练正确率: 98.23%	校验正确率: 98.58%
训练周期: 17 [25664/60000 (43%)]	Loss: 0.037656	训练正确率: 98.26%	校验正确率: 98.64%
训练周期: 17 [32064/60000 (53%)]	Loss: 0.089268	训练正确率: 98.28%	校验正确率: 98.68%
训练周期: 17 [38464/60000 (64%)]	Loss: 0.007872	训练正确率: 98.24%	校验正确率: 98.78%
训练周期: 17 [44864/60000 (75%)]	Loss: 0.011014	训练正确率: 98.21%	校验正确率: 98.60%
训练周期: 17 [51264/60000 (85%)]	Loss: 0.082866	训练正确率: 98.20%	校验正确率: 98.72%
训练周期: 17 [57664/60000 (96%)]	Loss: 0.223709	训练正确率: 98.23%	校验正确率: 98.74%
训练周期: 18 [64/60000 (0%)]	Loss: 0.012267	训练正确率: 100.00%	校验正确率: 98.84%
训练周期: 18 [6464/60000 (11%)]	Loss: 0.016755	训练正确率: 98.24%	校验正确率: 98.74%
训练周期: 18 [12864/60000 (21%)]	Loss: 0.071003	训练正确率: 98.27%	校验正确率: 98.90%
训练周期: 18 [19264/60000 (32%)]	Loss: 0.021377	训练正确率: 98.37%	校验正确率: 98.78%
训练周期: 18 [25664/60000 (43%)]	Loss: 0.019434	训练正确率: 98.35%	校验正确率: 98.92%
训练周期: 18 [32064/60000 (53%)]	Loss: 0.046915	训练正确率: 98.30%	校验正确率: 98.76%
训练周期: 18 [38464/60000 (64%)]	Loss: 0.029202	训练正确率: 98.25%	校验正确率: 98.90%
训练周期: 18 [44864/60000 (75%)]	Loss: 0.089349	训练正确率: 98.28%	校验正确率: 98.86%
训练周期: 18 [51264/60000 (85%)]	Loss: 0.014138	训练正确率: 98.29%	校验正确率: 98.82%
训练周期: 18 [57664/60000 (96%)]	Loss: 0.059072	训练正确率: 98.27%	校验正确率: 98.76%
训练周期: 19 [64/60000 (0%)]	Loss: 0.064476	训练正确率: 95.31%	校验正确率: 98.80%
训练周期: 19 [6464/60000 (11%)]	Loss: 0.045951	训练正确率: 98.02%	校验正确率: 98.78%
训练周期: 19 [12864/60000 (21%)]	Loss: 0.009943	训练正确率: 98.28%	校验正确率: 98.88%
训练周期: 19 [19264/60000 (32%)]	Loss: 0.006517	训练正确率: 98.28%	校验正确率: 98.72%
训练周期: 19 [25664/60000 (43%)]	Loss: 0.117733	训练正确率: 98.32%	校验正确率: 99.00%
训练周期: 19 [32064/60000 (53%)]	Loss: 0.012346	训练正确率: 98.30%	校验正确率: 98.82%
训练周期: 19 [38464/60000 (64%)]	Loss: 0.004596	训练正确率: 98.30%	校验正确率: 98.86%
训练周期: 19 [44864/60000 (75%)]	Loss: 0.057226	训练正确率: 98.31%	校验正确率: 98.72%
训练周期: 19 [51264/60000 (85%)]	Loss: 0.045235	训练正确率: 98.33%	校验正确率: 99.00%
训练周期: 19 [57664/60000 (96%)]	Loss: 0.069116	训练正确率: 98.33%	校验正确率: 98.94%

In [11]

paddle.save(net.state_dict(), 'minst_conv_checkpoint.pdparams')
'''
PaddlePaddle:使用paddle.save()保存模型参数。
PyTorch:使用torch.save()保存模型参数。
'''
'\nPaddlePaddle:使用paddle.save()保存模型参数。\nPyTorch:使用torch.save()保存模型参数。\n'

In [12]

# 提取第一层卷积层的卷积核
plt.figure(figsize=(10, 7))
for i in range(4):
    plt.subplot(1, 4, i + 1)
    plt.imshow(net.conv1.weight.numpy()[i, 0, ...])
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  if isinstance(obj, collections.Iterator):
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working
  return list(data) if isinstance(data, collections.MappingView) else data
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_min = np.asscalar(a_min.astype(scaled_dtype))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
  a_max = np.asscalar(a_max.astype(scaled_dtype))

<Figure size 1000x700 with 4 Axes>

In [13]

# 绘制第二层的卷积核
plt.figure(figsize=(15, 10))
for i in range(4):
    for j in range(8):
        plt.subplot(4, 8, i * 8 + j + 1)
        plt.imshow(net.conv2.weight.numpy()[j, i, ...])

<Figure size 1500x1000 with 32 Axes>

猜你喜欢

转载自blog.csdn.net/m0_68036862/article/details/131484015