简介:
在这篇博客文章中,我们将深入探讨如何使用卷积神经网络(Convolutional Neural Networks,简称CNN)来实现手写数字识别。这是机器学习中一个非常基础且广泛应用的问题,理解和掌握其解决方法,对于入门深度学习非常有帮助。
我们将从介绍卷积神经网络的基本原理开始,然后详细讲解如何构建和训练一个简单的CNN模型,最后我们会展示如何使用这个模型来识别手写数字。在这个过程中,我们将详细解释每一步的原理和代码,使得读者能够完全理解并且能够自己动手实现。
无论你是一位刚刚接触机器学习的新手,还是一位有一定基础并希望进一步深入理解CNN的研究者,我们都相信你能在这篇文章中收获到新的知识和启发。让我们一起开始这次探索之旅,感受卷积神经网络在手写数字识别上的强大之处!
# 查看当前挂载的数据集目录, 该目录下的变更重启环境后会自动还原
# View dataset directory.
# This directory will be recovered automatically after resetting environment.
!ls /home/aistudio/data
In [ ]
# 查看工作区文件, 该目录下的变更将会持久保存. 请及时清理不必要的文件, 避免加载过慢.
# View personal work directory.
# All changes under this directory will be kept even after reset.
# Please clean unnecessary files in time to speed up environment loading.
!ls /home/aistudio/work
In [ ]
# # 如果需要进行持久化安装, 需要使用持久化路径, 如下方代码示例:
# # If a persistence installation is required,
# # you need to use the persistence path as the following:
# !mkdir /home/aistudio/external-libraries
# !pip install beautifulsoup4 -t /home/aistudio/external-libraries
mkdir: cannot create directory ‘/home/aistudio/external-libraries’: File exists Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple Collecting beautifulsoup4 Downloading https://pypi.tuna.tsinghua.edu.cn/packages/57/f4/a69c20ee4f660081a7dedb1ac57f29be9378e04edfcb90c526b923d4bebc/beautifulsoup4-4.12.2-py3-none-any.whl (142 kB) ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 143.0/143.0 kB 7.6 MB/s eta 0:00:00 Collecting soupsieve>1.2 Downloading https://pypi.tuna.tsinghua.edu.cn/packages/49/37/673d6490efc51ec46d198c75903d99de59baffdd47aea3d071b80a9e4e89/soupsieve-2.4.1-py3-none-any.whl (36 kB) Installing collected packages: soupsieve, beautifulsoup4 Successfully installed beautifulsoup4-4.12.2 soupsieve-2.4.1 WARNING: Target directory /home/aistudio/external-libraries/bs4 already exists. Specify --upgrade to force replacement. WARNING: Target directory /home/aistudio/external-libraries/soupsieve-2.4.1.dist-info already exists. Specify --upgrade to force replacement. WARNING: Target directory /home/aistudio/external-libraries/beautifulsoup4-4.12.2.dist-info already exists. Specify --upgrade to force replacement. WARNING: Target directory /home/aistudio/external-libraries/soupsieve already exists. Specify --upgrade to force replacement. [notice] A new release of pip available: 22.1.2 -> 23.1.2 [notice] To update, run: pip install --upgrade pip
In [ ]
# 同时添加如下代码, 这样每次环境(kernel)启动的时候只要运行下方代码即可:
# Also add the following code,
# so that every time the environment (kernel) starts,
# just run the following code:
import sys
sys.path.append('/home/aistudio/external-libraries')
In [ ]
import paddle
import paddle.nn as nn
import paddle.optimizer as optim
import paddle.nn.functional as F
from paddle.vision.datasets import MNIST
import paddle.vision.transforms as transforms
import paddle.vision.models as models
import matplotlib.pyplot as plt
import numpy as np
# paddlepaddle的导入
'''
PaddlePaddle: import paddle ; import paddle.vision.transforms as transforms...
PyTorch: import torch ; import torchvision.transforms as transforms...
'''
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/__init__.py:107: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working from collections import MutableMapping /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/rcsetup.py:20: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working from collections import Iterable, Mapping /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/colors.py:53: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working from collections import Sized
In [ ]
image_size = 28
num_classes = 10
num_epochs = 20
batch_size = 64
learning_rate = 0.001
train_dataset = paddle.vision.datasets.MNIST(mode='train', transform=transforms.ToTensor())
test_dataset = paddle.vision.datasets.MNIST(mode='test', transform=transforms.ToTensor())
'''
PaddlePaddle:使用paddle.vision.datasets.MNIST加载MNIST数据集。
PyTorch:使用torchvision.datasets.MNIST加载MNIST数据集。
'''
# Generate random indices for validation and test sets
permutes = np.random.permutation(range(len(train_dataset)))
indices_val = permutes[:5000]
indices_test = permutes[5000:]
# Create Subset datasets for validation and test sets
val_dataset = paddle.io.Subset(train_dataset, indices_val)
test_dataset = paddle.io.Subset(train_dataset, indices_test)
'''
PaddlePaddle:使用paddle.io.Subset创建子集数据集。
PyTorch:
sampler_val = torch.utils.data.sampler.SubsetRandomSampler(indices_val)
sampler_test = torch.utils.data.sampler.SubsetRandomSampler(indices_test)
'''
# Create data loaders for train, validation, and test sets
train_loader = paddle.io.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = paddle.io.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = paddle.io.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
'''
PaddlePaddle:使用paddle.io.DataLoader创建
PyTorch:使用torch.utils.data.DataLoader创建
'''
In [ ]
depth = [4, 8]
class ConvNet(nn.Layer):
'''
PaddlePaddle:使用paddle.nn.Layer作为基类
PyTorch:使用torch.nn.Module作为基类
'''
def __init__(self):
super(ConvNet, self).__init__()
self.conv1 = nn.Conv2D(1, 4, 5, padding=2)
self.pool = nn.MaxPool2D(2, 2)
self.conv2 = nn.Conv2D(depth[0], depth[1], 5, padding=2)
self.fc1 = nn.Linear(image_size // 4 * image_size // 4 * 8, 512)
self.fc2 = nn.Linear(512, num_classes)
'''
PaddlePaddle:使用paddle.nn模块下的层定义,如paddle.nn.Conv2D
PyTorch:使用torch.nn模块下的层定义,如torch.nn.Conv2d
'''
def forward(self, x):
x = F.relu(self.conv1(x))
'''
PaddlePaddle:使用paddle.nn.functional模块下的函数
PyTorch:使用torch.nn.functional模块下的函数
'''
x = self.pool(x)
x = F.relu(self.conv2(x))
x = self.pool(x)
x = paddle.flatten(x, start_axis=1, stop_axis=-1)
'''
PaddlePaddle:使用paddle.flatten函数
PyTorch:使用torch.view函数
'''
x = F.relu(self.fc1(x))
x = F.dropout(x, p=0.5, training=self.training)
x = self.fc2(x)
x = F.log_softmax(x, axis=1)
return x
In [ ]
def rightness(predictions, labels):
pred = paddle.argmax(predictions, axis=1) #paddle.argmx
rights = paddle.sum(pred == labels) #paddle.sum
return rights, len(labels)
In [ ]
net = ConvNet()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Momentum(learning_rate=0.001, momentum=0.9, parameters=net.parameters())
'''
PaddlePaddle:使用paddle.nn.CrossEntropyLoss()
PyTorch:使用torch.nn.CrossEntropyLoss()
PaddlePaddle:使用paddle.optimizer.Momentum()定义优化器。
PyTorch:使用torch.optim.SGD定义优化器
'''
In [10]
record = [] # 记录准确率等数值的容器
weights = [] # 每若干步就记录一次卷积核
# 开始训练循环
for epoch in range(num_epochs):
train_rights = [] # 记录训练数据集准确率的容器
for batch_idx, (data, target) in enumerate(train_loader): # 针对容器中的每一个批进行循环
data = paddle.to_tensor(data)# PaddlePaddle:使用paddle.to_tensor()转换为PaddlePaddle张量。
target = paddle.to_tensor(target)
target = paddle.squeeze(target, axis=1) # 调整目标维度与预测输出相匹配
data.stop_gradient = False
target.stop_gradient = False
net.train() # 给网络模型做标记,表示模型在训练集上训练
output = net(data) # 完成一次预测
loss = criterion(output, target) # 计算误差
optimizer.clear_grad() # 清空梯度
'''
PaddlePaddle:使用optimizer.clear_grad()清空梯度。
PyTorch:使用optimizer.zero_grad()清空梯度。
'''
loss.backward() # 反向传播
optimizer.step() # 一步随机梯度下降
pred = paddle.argmax(output, axis=1)
rights = paddle.sum(paddle.cast(pred == target, dtype='int32'))
train_rights.append((rights.numpy()[0], data.shape[0]))
if batch_idx % 100 == 0: # 每间隔100个batch执行一次
# train_r为一个二元组,分别记录训练集中分类正确的数量和该集合中总的样本数
train_r = (np.sum([tup[0] for tup in train_rights]), np.sum([tup[1] for tup in train_rights]))
# print(train_r)
net.eval() # 给网络模型做标记,表示模型在训练集上评估
val_rights = [] # 记录校验数据集准确率的容器
for (data, target) in val_loader:
data = paddle.to_tensor(data)
target = paddle.to_tensor(target)
target = paddle.squeeze(target, axis=1)
output = net(data) # 完成一次预测
pred = paddle.argmax(output, axis=1)
rights = paddle.sum(paddle.cast(pred == target, dtype='int32'))
# print(rights)
val_rights.append((rights.numpy()[0], data.shape[0]))
# val_r为一个二元组,分别记录校验集中分类正确的数量和该集合中总的样本数
val_r = (np.sum([tup[0] for tup in val_rights]), np.sum([tup[1] for tup in val_rights]))
# print(val_r)
# 打印准确率等数值,其中正确率为本训练周期Epoch开始后到目前批次的正确率的平均值
train_accuracy = 100. * train_r[0] / train_r[1]
val_accuracy = 100. * val_r[0] / val_r[1]
print('训练周期: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t训练正确率: {:.2f}%\t校验正确率: {:.2f}%'.format(
epoch, train_r[1], len(train_loader.dataset),
100. * batch_idx / len(train_loader),
loss.item(),
train_accuracy,
val_accuracy))
# 将准确率和权重等数值加载到容器中,以方便后续处理
record.append((100. - train_accuracy, 100. - val_accuracy))
weights.append([param.numpy() for param in net.parameters()])
'''
PaddlePaddle:使用PaddlePaddle的函数,如paddle.argmax()等
PyTorch:使用PyTorch的函数,如torch.max()等
'''
训练周期: 0 [64/60000 (0%)] Loss: 0.294312 训练正确率: 90.62% 校验正确率: 94.04% 训练周期: 0 [6464/60000 (11%)] Loss: 0.200471 训练正确率: 93.04% 校验正确率: 94.36% 训练周期: 0 [12864/60000 (21%)] Loss: 0.172873 训练正确率: 93.14% 校验正确率: 94.64% 训练周期: 0 [19264/60000 (32%)] Loss: 0.267028 训练正确率: 93.29% 校验正确率: 94.92% 训练周期: 0 [25664/60000 (43%)] Loss: 0.082636 训练正确率: 93.30% 校验正确率: 95.18% 训练周期: 0 [32064/60000 (53%)] Loss: 0.197915 训练正确率: 93.43% 校验正确率: 95.24% 训练周期: 0 [38464/60000 (64%)] Loss: 0.189998 训练正确率: 93.49% 校验正确率: 95.32% 训练周期: 0 [44864/60000 (75%)] Loss: 0.223470 训练正确率: 93.55% 校验正确率: 95.44% 训练周期: 0 [51264/60000 (85%)] Loss: 0.112054 训练正确率: 93.60% 校验正确率: 95.62% 训练周期: 0 [57664/60000 (96%)] Loss: 0.045446 训练正确率: 93.71% 校验正确率: 95.96% 训练周期: 1 [64/60000 (0%)] Loss: 0.162490 训练正确率: 96.88% 校验正确率: 95.80% 训练周期: 1 [6464/60000 (11%)] Loss: 0.154191 训练正确率: 94.46% 校验正确率: 95.98% 训练周期: 1 [12864/60000 (21%)] Loss: 0.163062 训练正确率: 94.61% 校验正确率: 96.10% 训练周期: 1 [19264/60000 (32%)] Loss: 0.098654 训练正确率: 94.65% 校验正确率: 96.14% 训练周期: 1 [25664/60000 (43%)] Loss: 0.142942 训练正确率: 94.71% 校验正确率: 96.04% 训练周期: 1 [32064/60000 (53%)] Loss: 0.058082 训练正确率: 94.79% 校验正确率: 96.24% 训练周期: 1 [38464/60000 (64%)] Loss: 0.056331 训练正确率: 94.83% 校验正确率: 96.42% 训练周期: 1 [44864/60000 (75%)] Loss: 0.296984 训练正确率: 94.88% 校验正确率: 96.38% 训练周期: 1 [51264/60000 (85%)] Loss: 0.181992 训练正确率: 94.89% 校验正确率: 96.34% 训练周期: 1 [57664/60000 (96%)] Loss: 0.269135 训练正确率: 94.90% 校验正确率: 96.48% 训练周期: 2 [64/60000 (0%)] Loss: 0.245220 训练正确率: 93.75% 校验正确率: 96.32% 训练周期: 2 [6464/60000 (11%)] Loss: 0.195959 训练正确率: 95.41% 校验正确率: 96.46% 训练周期: 2 [12864/60000 (21%)] Loss: 0.076648 训练正确率: 95.40% 校验正确率: 96.56% 训练周期: 2 [19264/60000 (32%)] Loss: 0.103266 训练正确率: 95.38% 校验正确率: 96.56% 训练周期: 2 [25664/60000 (43%)] Loss: 0.326429 训练正确率: 95.34% 校验正确率: 96.68% 训练周期: 2 [32064/60000 (53%)] Loss: 0.164600 训练正确率: 95.46% 校验正确率: 96.64% 训练周期: 2 [38464/60000 (64%)] Loss: 0.100820 训练正确率: 95.44% 校验正确率: 96.80% 训练周期: 2 [44864/60000 (75%)] Loss: 0.082001 训练正确率: 95.48% 校验正确率: 96.82% 训练周期: 2 [51264/60000 (85%)] Loss: 0.119786 训练正确率: 95.52% 校验正确率: 96.82% 训练周期: 2 [57664/60000 (96%)] Loss: 0.237487 训练正确率: 95.53% 校验正确率: 96.94% 训练周期: 3 [64/60000 (0%)] Loss: 0.058802 训练正确率: 100.00% 校验正确率: 96.78% 训练周期: 3 [6464/60000 (11%)] Loss: 0.155010 训练正确率: 96.09% 校验正确率: 97.06% 训练周期: 3 [12864/60000 (21%)] Loss: 0.042440 训练正确率: 95.99% 校验正确率: 96.92% 训练周期: 3 [19264/60000 (32%)] Loss: 0.118062 训练正确率: 95.98% 校验正确率: 97.02% 训练周期: 3 [25664/60000 (43%)] Loss: 0.100673 训练正确率: 96.02% 校验正确率: 97.08% 训练周期: 3 [32064/60000 (53%)] Loss: 0.172493 训练正确率: 95.99% 校验正确率: 97.00% 训练周期: 3 [38464/60000 (64%)] Loss: 0.187268 训练正确率: 95.98% 校验正确率: 97.26% 训练周期: 3 [44864/60000 (75%)] Loss: 0.061519 训练正确率: 96.00% 校验正确率: 97.30% 训练周期: 3 [51264/60000 (85%)] Loss: 0.073172 训练正确率: 96.07% 校验正确率: 96.98% 训练周期: 3 [57664/60000 (96%)] Loss: 0.055029 训练正确率: 96.06% 校验正确率: 97.08% 训练周期: 4 [64/60000 (0%)] Loss: 0.065538 训练正确率: 98.44% 校验正确率: 97.18% 训练周期: 4 [6464/60000 (11%)] Loss: 0.039354 训练正确率: 96.49% 校验正确率: 97.44% 训练周期: 4 [12864/60000 (21%)] Loss: 0.113937 训练正确率: 96.46% 校验正确率: 97.06% 训练周期: 4 [19264/60000 (32%)] Loss: 0.088527 训练正确率: 96.46% 校验正确率: 97.26% 训练周期: 4 [25664/60000 (43%)] Loss: 0.074021 训练正确率: 96.52% 校验正确率: 97.40% 训练周期: 4 [32064/60000 (53%)] Loss: 0.161240 训练正确率: 96.51% 校验正确率: 97.46% 训练周期: 4 [38464/60000 (64%)] Loss: 0.067512 训练正确率: 96.51% 校验正确率: 97.48% 训练周期: 4 [44864/60000 (75%)] Loss: 0.118248 训练正确率: 96.44% 校验正确率: 97.46% 训练周期: 4 [51264/60000 (85%)] Loss: 0.059335 训练正确率: 96.44% 校验正确率: 97.54% 训练周期: 4 [57664/60000 (96%)] Loss: 0.059810 训练正确率: 96.48% 校验正确率: 97.48% 训练周期: 5 [64/60000 (0%)] Loss: 0.116151 训练正确率: 96.88% 校验正确率: 97.38% 训练周期: 5 [6464/60000 (11%)] Loss: 0.160240 训练正确率: 96.97% 校验正确率: 97.50% 训练周期: 5 [12864/60000 (21%)] Loss: 0.272595 训练正确率: 96.76% 校验正确率: 97.48% 训练周期: 5 [19264/60000 (32%)] Loss: 0.092262 训练正确率: 96.65% 校验正确率: 97.62% 训练周期: 5 [25664/60000 (43%)] Loss: 0.061020 训练正确率: 96.64% 校验正确率: 97.62% 训练周期: 5 [32064/60000 (53%)] Loss: 0.140946 训练正确率: 96.65% 校验正确率: 97.62% 训练周期: 5 [38464/60000 (64%)] Loss: 0.172838 训练正确率: 96.69% 校验正确率: 97.60% 训练周期: 5 [44864/60000 (75%)] Loss: 0.141658 训练正确率: 96.69% 校验正确率: 97.70% 训练周期: 5 [51264/60000 (85%)] Loss: 0.102514 训练正确率: 96.72% 校验正确率: 97.80% 训练周期: 5 [57664/60000 (96%)] Loss: 0.206360 训练正确率: 96.74% 校验正确率: 97.68% 训练周期: 6 [64/60000 (0%)] Loss: 0.174180 训练正确率: 95.31% 校验正确率: 97.74% 训练周期: 6 [6464/60000 (11%)] Loss: 0.192811 训练正确率: 96.41% 校验正确率: 97.80% 训练周期: 6 [12864/60000 (21%)] Loss: 0.079086 训练正确率: 96.49% 校验正确率: 97.76% 训练周期: 6 [19264/60000 (32%)] Loss: 0.094231 训练正确率: 96.68% 校验正确率: 97.68% 训练周期: 6 [25664/60000 (43%)] Loss: 0.120969 训练正确率: 96.82% 校验正确率: 97.72% 训练周期: 6 [32064/60000 (53%)] Loss: 0.039326 训练正确率: 96.90% 校验正确率: 97.78% 训练周期: 6 [38464/60000 (64%)] Loss: 0.054882 训练正确率: 97.00% 校验正确率: 97.70% 训练周期: 6 [44864/60000 (75%)] Loss: 0.161812 训练正确率: 97.01% 校验正确率: 97.76% 训练周期: 6 [51264/60000 (85%)] Loss: 0.109295 训练正确率: 96.99% 校验正确率: 97.74% 训练周期: 6 [57664/60000 (96%)] Loss: 0.092668 训练正确率: 97.02% 校验正确率: 97.68% 训练周期: 7 [64/60000 (0%)] Loss: 0.072402 训练正确率: 98.44% 校验正确率: 97.88% 训练周期: 7 [6464/60000 (11%)] Loss: 0.034500 训练正确率: 97.28% 校验正确率: 97.90% 训练周期: 7 [12864/60000 (21%)] Loss: 0.032573 训练正确率: 97.16% 校验正确率: 97.88% 训练周期: 7 [19264/60000 (32%)] Loss: 0.123999 训练正确率: 97.13% 校验正确率: 97.84% 训练周期: 7 [25664/60000 (43%)] Loss: 0.116971 训练正确率: 97.16% 校验正确率: 97.82% 训练周期: 7 [32064/60000 (53%)] Loss: 0.034661 训练正确率: 97.20% 校验正确率: 97.88% 训练周期: 7 [38464/60000 (64%)] Loss: 0.060708 训练正确率: 97.20% 校验正确率: 97.94% 训练周期: 7 [44864/60000 (75%)] Loss: 0.125912 训练正确率: 97.23% 校验正确率: 97.96% 训练周期: 7 [51264/60000 (85%)] Loss: 0.186194 训练正确率: 97.22% 校验正确率: 97.98% 训练周期: 7 [57664/60000 (96%)] Loss: 0.087361 训练正确率: 97.23% 校验正确率: 98.08% 训练周期: 8 [64/60000 (0%)] Loss: 0.009403 训练正确率: 100.00% 校验正确率: 97.86% 训练周期: 8 [6464/60000 (11%)] Loss: 0.326886 训练正确率: 97.42% 校验正确率: 98.00% 训练周期: 8 [12864/60000 (21%)] Loss: 0.060366 训练正确率: 97.32% 校验正确率: 98.02% 训练周期: 8 [19264/60000 (32%)] Loss: 0.161973 训练正确率: 97.31% 校验正确率: 98.10% 训练周期: 8 [25664/60000 (43%)] Loss: 0.071160 训练正确率: 97.35% 校验正确率: 98.04% 训练周期: 8 [32064/60000 (53%)] Loss: 0.017569 训练正确率: 97.37% 校验正确率: 98.10% 训练周期: 8 [38464/60000 (64%)] Loss: 0.050663 训练正确率: 97.38% 校验正确率: 98.20% 训练周期: 8 [44864/60000 (75%)] Loss: 0.222503 训练正确率: 97.37% 校验正确率: 98.18% 训练周期: 8 [51264/60000 (85%)] Loss: 0.109463 训练正确率: 97.36% 校验正确率: 98.14% 训练周期: 8 [57664/60000 (96%)] Loss: 0.080445 训练正确率: 97.35% 校验正确率: 98.10% 训练周期: 9 [64/60000 (0%)] Loss: 0.017284 训练正确率: 100.00% 校验正确率: 98.06% 训练周期: 9 [6464/60000 (11%)] Loss: 0.045179 训练正确率: 97.56% 校验正确率: 98.20% 训练周期: 9 [12864/60000 (21%)] Loss: 0.031334 训练正确率: 97.56% 校验正确率: 98.04% 训练周期: 9 [19264/60000 (32%)] Loss: 0.038351 训练正确率: 97.60% 校验正确率: 98.14% 训练周期: 9 [25664/60000 (43%)] Loss: 0.121472 训练正确率: 97.59% 校验正确率: 98.30% 训练周期: 9 [32064/60000 (53%)] Loss: 0.085770 训练正确率: 97.51% 校验正确率: 98.28% 训练周期: 9 [38464/60000 (64%)] Loss: 0.059464 训练正确率: 97.48% 校验正确率: 98.32% 训练周期: 9 [44864/60000 (75%)] Loss: 0.095663 训练正确率: 97.57% 校验正确率: 98.08% 训练周期: 9 [51264/60000 (85%)] Loss: 0.113621 训练正确率: 97.54% 校验正确率: 98.34% 训练周期: 9 [57664/60000 (96%)] Loss: 0.268711 训练正确率: 97.52% 校验正确率: 98.30% 训练周期: 10 [64/60000 (0%)] Loss: 0.084389 训练正确率: 96.88% 校验正确率: 98.22% 训练周期: 10 [6464/60000 (11%)] Loss: 0.055922 训练正确率: 97.62% 校验正确率: 98.34% 训练周期: 10 [12864/60000 (21%)] Loss: 0.125814 训练正确率: 97.71% 校验正确率: 98.24% 训练周期: 10 [19264/60000 (32%)] Loss: 0.045647 训练正确率: 97.67% 校验正确率: 98.18% 训练周期: 10 [25664/60000 (43%)] Loss: 0.054552 训练正确率: 97.70% 校验正确率: 98.34% 训练周期: 10 [32064/60000 (53%)] Loss: 0.017988 训练正确率: 97.65% 校验正确率: 98.42% 训练周期: 10 [38464/60000 (64%)] Loss: 0.034680 训练正确率: 97.68% 校验正确率: 98.34% 训练周期: 10 [44864/60000 (75%)] Loss: 0.068128 训练正确率: 97.68% 校验正确率: 98.24% 训练周期: 10 [51264/60000 (85%)] Loss: 0.048264 训练正确率: 97.67% 校验正确率: 98.38% 训练周期: 10 [57664/60000 (96%)] Loss: 0.040969 训练正确率: 97.68% 校验正确率: 98.32% 训练周期: 11 [64/60000 (0%)] Loss: 0.095555 训练正确率: 98.44% 校验正确率: 98.28% 训练周期: 11 [6464/60000 (11%)] Loss: 0.080019 训练正确率: 97.66% 校验正确率: 98.26% 训练周期: 11 [12864/60000 (21%)] Loss: 0.099818 训练正确率: 97.64% 校验正确率: 98.36% 训练周期: 11 [19264/60000 (32%)] Loss: 0.057005 训练正确率: 97.72% 校验正确率: 98.26% 训练周期: 11 [25664/60000 (43%)] Loss: 0.032954 训练正确率: 97.76% 校验正确率: 98.54% 训练周期: 11 [32064/60000 (53%)] Loss: 0.053869 训练正确率: 97.73% 校验正确率: 98.36% 训练周期: 11 [38464/60000 (64%)] Loss: 0.062431 训练正确率: 97.70% 校验正确率: 98.30% 训练周期: 11 [44864/60000 (75%)] Loss: 0.072020 训练正确率: 97.75% 校验正确率: 98.40% 训练周期: 11 [51264/60000 (85%)] Loss: 0.010781 训练正确率: 97.75% 校验正确率: 98.48% 训练周期: 11 [57664/60000 (96%)] Loss: 0.017668 训练正确率: 97.76% 校验正确率: 98.58% 训练周期: 12 [64/60000 (0%)] Loss: 0.012428 训练正确率: 100.00% 校验正确率: 98.38% 训练周期: 12 [6464/60000 (11%)] Loss: 0.055729 训练正确率: 98.21% 校验正确率: 98.46% 训练周期: 12 [12864/60000 (21%)] Loss: 0.097194 训练正确率: 97.96% 校验正确率: 98.48% 训练周期: 12 [19264/60000 (32%)] Loss: 0.071802 训练正确率: 97.87% 校验正确率: 98.32% 训练周期: 12 [25664/60000 (43%)] Loss: 0.078902 训练正确率: 97.91% 校验正确率: 98.38% 训练周期: 12 [32064/60000 (53%)] Loss: 0.247189 训练正确率: 97.88% 校验正确率: 98.54% 训练周期: 12 [38464/60000 (64%)] Loss: 0.037298 训练正确率: 97.85% 校验正确率: 98.56% 训练周期: 12 [44864/60000 (75%)] Loss: 0.026617 训练正确率: 97.88% 校验正确率: 98.48% 训练周期: 12 [51264/60000 (85%)] Loss: 0.024428 训练正确率: 97.86% 校验正确率: 98.46% 训练周期: 12 [57664/60000 (96%)] Loss: 0.017038 训练正确率: 97.86% 校验正确率: 98.52% 训练周期: 13 [64/60000 (0%)] Loss: 0.067459 训练正确率: 98.44% 校验正确率: 98.60% 训练周期: 13 [6464/60000 (11%)] Loss: 0.032820 训练正确率: 97.88% 校验正确率: 98.56% 训练周期: 13 [12864/60000 (21%)] Loss: 0.009003 训练正确率: 98.03% 校验正确率: 98.60% 训练周期: 13 [19264/60000 (32%)] Loss: 0.078738 训练正确率: 98.01% 校验正确率: 98.44% 训练周期: 13 [25664/60000 (43%)] Loss: 0.182921 训练正确率: 97.94% 校验正确率: 98.46% 训练周期: 13 [32064/60000 (53%)] Loss: 0.118050 训练正确率: 97.90% 校验正确率: 98.58% 训练周期: 13 [38464/60000 (64%)] Loss: 0.024995 训练正确率: 97.90% 校验正确率: 98.52% 训练周期: 13 [44864/60000 (75%)] Loss: 0.065351 训练正确率: 97.91% 校验正确率: 98.50% 训练周期: 13 [51264/60000 (85%)] Loss: 0.037466 训练正确率: 97.94% 校验正确率: 98.50% 训练周期: 13 [57664/60000 (96%)] Loss: 0.086423 训练正确率: 97.91% 校验正确率: 98.54% 训练周期: 14 [64/60000 (0%)] Loss: 0.024936 训练正确率: 100.00% 校验正确率: 98.68% 训练周期: 14 [6464/60000 (11%)] Loss: 0.035491 训练正确率: 98.24% 校验正确率: 98.44% 训练周期: 14 [12864/60000 (21%)] Loss: 0.097156 训练正确率: 98.27% 校验正确率: 98.56% 训练周期: 14 [19264/60000 (32%)] Loss: 0.163661 训练正确率: 98.22% 校验正确率: 98.60% 训练周期: 14 [25664/60000 (43%)] Loss: 0.070727 训练正确率: 98.09% 校验正确率: 98.44% 训练周期: 14 [32064/60000 (53%)] Loss: 0.006786 训练正确率: 98.10% 校验正确率: 98.42% 训练周期: 14 [38464/60000 (64%)] Loss: 0.041515 训练正确率: 98.05% 校验正确率: 98.54% 训练周期: 14 [44864/60000 (75%)] Loss: 0.009068 训练正确率: 98.05% 校验正确率: 98.40% 训练周期: 14 [51264/60000 (85%)] Loss: 0.017752 训练正确率: 98.06% 校验正确率: 98.66% 训练周期: 14 [57664/60000 (96%)] Loss: 0.070509 训练正确率: 98.03% 校验正确率: 98.58% 训练周期: 15 [64/60000 (0%)] Loss: 0.026125 训练正确率: 98.44% 校验正确率: 98.58% 训练周期: 15 [6464/60000 (11%)] Loss: 0.016736 训练正确率: 98.04% 校验正确率: 98.40% 训练周期: 15 [12864/60000 (21%)] Loss: 0.035270 训练正确率: 97.87% 校验正确率: 98.48% 训练周期: 15 [19264/60000 (32%)] Loss: 0.015142 训练正确率: 98.00% 校验正确率: 98.76% 训练周期: 15 [25664/60000 (43%)] Loss: 0.131374 训练正确率: 98.00% 校验正确率: 98.56% 训练周期: 15 [32064/60000 (53%)] Loss: 0.114820 训练正确率: 97.99% 校验正确率: 98.52% 训练周期: 15 [38464/60000 (64%)] Loss: 0.066392 训练正确率: 97.97% 校验正确率: 98.56% 训练周期: 15 [44864/60000 (75%)] Loss: 0.007783 训练正确率: 98.03% 校验正确率: 98.62% 训练周期: 15 [51264/60000 (85%)] Loss: 0.025519 训练正确率: 98.02% 校验正确率: 98.74% 训练周期: 15 [57664/60000 (96%)] Loss: 0.029439 训练正确率: 98.02% 校验正确率: 98.66% 训练周期: 16 [64/60000 (0%)] Loss: 0.101036 训练正确率: 98.44% 校验正确率: 98.58% 训练周期: 16 [6464/60000 (11%)] Loss: 0.091292 训练正确率: 98.04% 校验正确率: 98.52% 训练周期: 16 [12864/60000 (21%)] Loss: 0.057263 训练正确率: 97.97% 校验正确率: 98.74% 训练周期: 16 [19264/60000 (32%)] Loss: 0.012148 训练正确率: 98.12% 校验正确率: 98.72% 训练周期: 16 [25664/60000 (43%)] Loss: 0.077445 训练正确率: 98.17% 校验正确率: 98.62% 训练周期: 16 [32064/60000 (53%)] Loss: 0.026971 训练正确率: 98.18% 校验正确率: 98.72% 训练周期: 16 [38464/60000 (64%)] Loss: 0.014428 训练正确率: 98.19% 校验正确率: 98.76% 训练周期: 16 [44864/60000 (75%)] Loss: 0.048128 训练正确率: 98.17% 校验正确率: 98.72% 训练周期: 16 [51264/60000 (85%)] Loss: 0.020594 训练正确率: 98.18% 校验正确率: 98.76% 训练周期: 16 [57664/60000 (96%)] Loss: 0.048687 训练正确率: 98.14% 校验正确率: 98.66% 训练周期: 17 [64/60000 (0%)] Loss: 0.116186 训练正确率: 98.44% 校验正确率: 98.68% 训练周期: 17 [6464/60000 (11%)] Loss: 0.045756 训练正确率: 98.44% 校验正确率: 98.70% 训练周期: 17 [12864/60000 (21%)] Loss: 0.093706 训练正确率: 98.31% 校验正确率: 98.82% 训练周期: 17 [19264/60000 (32%)] Loss: 0.153490 训练正确率: 98.23% 校验正确率: 98.58% 训练周期: 17 [25664/60000 (43%)] Loss: 0.037656 训练正确率: 98.26% 校验正确率: 98.64% 训练周期: 17 [32064/60000 (53%)] Loss: 0.089268 训练正确率: 98.28% 校验正确率: 98.68% 训练周期: 17 [38464/60000 (64%)] Loss: 0.007872 训练正确率: 98.24% 校验正确率: 98.78% 训练周期: 17 [44864/60000 (75%)] Loss: 0.011014 训练正确率: 98.21% 校验正确率: 98.60% 训练周期: 17 [51264/60000 (85%)] Loss: 0.082866 训练正确率: 98.20% 校验正确率: 98.72% 训练周期: 17 [57664/60000 (96%)] Loss: 0.223709 训练正确率: 98.23% 校验正确率: 98.74% 训练周期: 18 [64/60000 (0%)] Loss: 0.012267 训练正确率: 100.00% 校验正确率: 98.84% 训练周期: 18 [6464/60000 (11%)] Loss: 0.016755 训练正确率: 98.24% 校验正确率: 98.74% 训练周期: 18 [12864/60000 (21%)] Loss: 0.071003 训练正确率: 98.27% 校验正确率: 98.90% 训练周期: 18 [19264/60000 (32%)] Loss: 0.021377 训练正确率: 98.37% 校验正确率: 98.78% 训练周期: 18 [25664/60000 (43%)] Loss: 0.019434 训练正确率: 98.35% 校验正确率: 98.92% 训练周期: 18 [32064/60000 (53%)] Loss: 0.046915 训练正确率: 98.30% 校验正确率: 98.76% 训练周期: 18 [38464/60000 (64%)] Loss: 0.029202 训练正确率: 98.25% 校验正确率: 98.90% 训练周期: 18 [44864/60000 (75%)] Loss: 0.089349 训练正确率: 98.28% 校验正确率: 98.86% 训练周期: 18 [51264/60000 (85%)] Loss: 0.014138 训练正确率: 98.29% 校验正确率: 98.82% 训练周期: 18 [57664/60000 (96%)] Loss: 0.059072 训练正确率: 98.27% 校验正确率: 98.76% 训练周期: 19 [64/60000 (0%)] Loss: 0.064476 训练正确率: 95.31% 校验正确率: 98.80% 训练周期: 19 [6464/60000 (11%)] Loss: 0.045951 训练正确率: 98.02% 校验正确率: 98.78% 训练周期: 19 [12864/60000 (21%)] Loss: 0.009943 训练正确率: 98.28% 校验正确率: 98.88% 训练周期: 19 [19264/60000 (32%)] Loss: 0.006517 训练正确率: 98.28% 校验正确率: 98.72% 训练周期: 19 [25664/60000 (43%)] Loss: 0.117733 训练正确率: 98.32% 校验正确率: 99.00% 训练周期: 19 [32064/60000 (53%)] Loss: 0.012346 训练正确率: 98.30% 校验正确率: 98.82% 训练周期: 19 [38464/60000 (64%)] Loss: 0.004596 训练正确率: 98.30% 校验正确率: 98.86% 训练周期: 19 [44864/60000 (75%)] Loss: 0.057226 训练正确率: 98.31% 校验正确率: 98.72% 训练周期: 19 [51264/60000 (85%)] Loss: 0.045235 训练正确率: 98.33% 校验正确率: 99.00% 训练周期: 19 [57664/60000 (96%)] Loss: 0.069116 训练正确率: 98.33% 校验正确率: 98.94%
In [11]
paddle.save(net.state_dict(), 'minst_conv_checkpoint.pdparams')
'''
PaddlePaddle:使用paddle.save()保存模型参数。
PyTorch:使用torch.save()保存模型参数。
'''
'\nPaddlePaddle:使用paddle.save()保存模型参数。\nPyTorch:使用torch.save()保存模型参数。\n'
In [12]
# 提取第一层卷积层的卷积核
plt.figure(figsize=(10, 7))
for i in range(4):
plt.subplot(1, 4, i + 1)
plt.imshow(net.conv1.weight.numpy()[i, 0, ...])
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2349: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working if isinstance(obj, collections.Iterator): /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/cbook/__init__.py:2366: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated, and in 3.8 it will stop working return list(data) if isinstance(data, collections.MappingView) else data /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:425: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead a_min = np.asscalar(a_min.astype(scaled_dtype)) /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/matplotlib/image.py:426: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead a_max = np.asscalar(a_max.astype(scaled_dtype))
<Figure size 1000x700 with 4 Axes>
In [13]
# 绘制第二层的卷积核
plt.figure(figsize=(15, 10))
for i in range(4):
for j in range(8):
plt.subplot(4, 8, i * 8 + j + 1)
plt.imshow(net.conv2.weight.numpy()[j, i, ...])
<Figure size 1500x1000 with 32 Axes>