Hands-on MNIST classification (CPU version + GPU version)

1. Introduction to the dataset

Before we start, I assume everyone is already familiar with the MNIST dataset, so I will not go into detail here and will only list its key parameters:
Categories: digits 0 through 9, 10 classes in total;
Quantity: 70,000 grayscale images in total, 60,000 for training and 10,000 for testing; each image has a label (the digit 0 corresponds to label 0, the digit 1 to label 1, and so on);
Pixel size: 28×28;
Number of channels: single channel.
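
If you want to verify these numbers yourself, here is a quick check with torchvision (my own addition; it assumes torchvision is installed and downloads the data into ./data):

from torchvision.datasets import mnist
from torchvision import transforms

train_set = mnist.MNIST('./data', train=True, transform=transforms.ToTensor(), download=True)
test_set = mnist.MNIST('./data', train=False, transform=transforms.ToTensor(), download=True)

img, label = train_set[0]
print(len(train_set), len(test_set))  # 60000 10000
print(img.shape)                      # torch.Size([1, 28, 28]) -- 1 channel, 28x28 pixels
print(label)                          # the label of the first training image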

2. Introduction to the network

The network used here is fairly simple; variants of it can be found in many forums, papers, and journals, so I will not go into detail.
Define the network class:

import torch.nn as nn

class Net(nn.Module):
    def __init__(self, in_c=784, out_c=10):
        super(Net, self).__init__()

        # Fully connected layer
        self.fc1 = nn.Linear(in_c, 512)
        # Activation layer
        self.act1 = nn.ReLU(inplace=True)

        self.fc2 = nn.Linear(512, 256)
        self.act2 = nn.ReLU(inplace=True)
        self.fc3 = nn.Linear(256, 128)
        self.act3 = nn.ReLU(inplace=True)

        self.fc4 = nn.Linear(128, out_c)

    def forward(self, x):
        x = self.act1(self.fc1(x))
        x = self.act2(self.fc2(x))
        x = self.act3(self.fc3(x))
        x = self.fc4(x)
        return x
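A quick sanity check I like to add (not in the original): feed a dummy batch through the network to confirm the shapes line up.

import torch

net = Net()
dummy = torch.randn(4, 784)  # a batch of 4 flattened 28x28 images
out = net(dummy)
print(out.shape)             # torch.Size([4, 10]) -- one raw score (logit) per class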

3. Hands-on CPU version

With the network in place, we can start hand-rolling the CPU version of MNIST classification.
The workflow:
3.1 Get the class (the network)
3.2 Get the training set and test set
3.3 Load the training set and test set (DataLoader)
3.4 Define the loss function -- cross-entropy (see the note just after this list)
3.5 Define the optimizer -- stochastic gradient descent
3.6 Create for-loop A (the epoch loop) and create empty lists to record training loss and accuracy
3.7 Alongside loop A: create empty lists to record test loss and accuracy (note: keep these separate from the training records)
3.8 Set the number of training epochs
3.9 Build for-loop B (the batch loop); at the top of each epoch, initialize the per-epoch training loss and accuracy
3.10 In loop B: switch the network to training mode
3.11 In loop B: iterate over batches of images and labels
3.12 In loop B: wrap the images and labels in Variable
3.13 In loop B: forward pass, out = net(img); loss = criterion(out, label)
3.14 In loop B: record the loss
3.15 In loop B: compute the classification accuracy
3.16 After loop B ends (still inside loop A): append the loss and accuracy
3.17 In loop A, add a second inner loop, this time over the test set (the test set is not trained on); everything else stays the same
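
One point worth spelling out (my addition): the forward pass ends with a plain linear layer and no softmax because nn.CrossEntropyLoss expects raw logits and applies log-softmax internally. A minimal illustration:

import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.tensor([[2.0, 0.5, -1.0]])  # raw scores for 3 classes, batch of 1
target = torch.tensor([0])                 # index of the true class
print(criterion(logits, target))           # log-softmax + negative log-likelihood, handled internally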

The CPU demo:

import time
import torch
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.datasets import mnist
from torchvision import transforms



class Net(nn.Module):
    def __init__(self, in_c=784, out_c=10):
        super(Net, self).__init__()

        # Fully connected layer
        self.fc1 = nn.Linear(in_c, 512)
        # Activation layer
        self.act1 = nn.ReLU(inplace=True)

        self.fc2 = nn.Linear(512, 256)
        self.act2 = nn.ReLU(inplace=True)
        self.fc3 = nn.Linear(256, 128)
        self.act3 = nn.ReLU(inplace=True)

        self.fc4 = nn.Linear(128, out_c)

    def forward(self, x):
        x = self.act1(self.fc1(x))
        x = self.act2(self.fc2(x))
        x = self.act3(self.fc3(x))
        x = self.fc4(x)
        return x


t1 = time.time()
# Build the network
net = Net()
# Training set (download=True fetches the data on the first run)
train_set = mnist.MNIST('./data', train=True, transform=transforms.ToTensor(), download=True)
# Test set
test_set = mnist.MNIST('./data', train=False, transform=transforms.ToTensor(), download=True)
# Training set loader
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
# Test set loader
test_data = DataLoader(test_set, batch_size=64, shuffle=True)

# Define the loss function -- cross-entropy
criterion = nn.CrossEntropyLoss()
# Define the optimizer -- stochastic gradient descent
optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=0.00005)

# Start training
losses = []  # record training loss
acces = []   # record training accuracy

eval_losses = []  # record test loss
eval_acces = []   # record test accuracy
nums_epoch = 20   # number of epochs

for epoch in range(nums_epoch):
    train_loss = 0  # running training loss for this epoch
    train_acc = 0   # running training accuracy for this epoch
    net.train()
    for batch, (img, label) in enumerate(train_data):
        # Flatten each 1x28x28 image into a 784-dim vector
        img = img.reshape(img.size(0), -1)
        img = Variable(img)
        label = Variable(label)
        # Forward pass
        out = net(img)
        loss = criterion(out, label)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Record the loss
        train_loss += loss.item()
        # Compute the classification accuracy
        _, pred = out.max(1)
        num_correct = (pred == label).sum().item()
        acc = num_correct / img.shape[0]
        # Accumulate accuracy for every batch, not only the printed ones
        train_acc += acc
        if (batch + 1) % 200 == 0:
            print('[INFO] Epoch-{}-Batch-{}: Train: Loss-{:.4f},Accuracy-{:.4f}'.format(epoch+1, batch+1, loss.item(), acc))

    losses.append(train_loss / len(train_data))
    acces.append(train_acc / len(train_data))
    eval_loss = 0
    eval_acc = 0
    # The test set is not trained on
    net.eval()
    with torch.no_grad():
        for img, label in test_data:
            img = img.reshape(img.size(0), -1)
            img = Variable(img)
            label = Variable(label)

            out = net(img)
            loss = criterion(out, label)
            eval_loss += loss.item()
            _, pred = out.max(1)
            num_correct = (pred == label).sum().item()
            acc = num_correct / img.shape[0]
            eval_acc += acc
    eval_losses.append(eval_loss / len(test_data))
    eval_acces.append(eval_acc / len(test_data))
    # Print the statistics
    set_epoch = epoch+1
    set_lossTrain = train_loss / len(train_data)
    set_AccTrain = train_acc / len(train_data)
    set_lossEval = eval_loss / len(test_data)
    set_AccEval = eval_acc / len(test_data)

    print('[INFO] Epoch-{}: Train: Loss-{:.4f},Accuracy-{:.4f} |Test:Loss-{:.4f}, Accuracy-{:.4f}'.format(set_epoch,
    set_lossTrain, set_AccTrain, set_lossEval, set_AccEval))
t2 = time.time()
t = t2 - t1
print(t)  # total elapsed time in seconds

Output of the CPU demo:

[INFO] Epoch-1-Batch-200: Train: Loss-2.2869,Accuracy-0.2500
[INFO] Epoch-1-Batch-400: Train: Loss-2.2628,Accuracy-0.4062
[INFO] Epoch-1-Batch-600: Train: Loss-2.2056,Accuracy-0.5156
[INFO] Epoch-1-Batch-800: Train: Loss-1.9502,Accuracy-0.6875
[INFO] Epoch-1: Train: Loss-2.1569,Accuracy-0.0020 |Test:Loss-1.5634, Accuracy-0.6395
[INFO] Epoch-2-Batch-200: Train: Loss-0.9539,Accuracy-0.8125
[INFO] Epoch-2-Batch-400: Train: Loss-0.8835,Accuracy-0.7500
[INFO] Epoch-2-Batch-600: Train: Loss-0.5718,Accuracy-0.8281
[INFO] Epoch-2-Batch-800: Train: Loss-0.5851,Accuracy-0.8125
[INFO] Epoch-2: Train: Loss-0.8250,Accuracy-0.0034 |Test:Loss-0.5095, Accuracy-0.8504
[INFO] Epoch-3-Batch-200: Train: Loss-0.4938,Accuracy-0.7500
[INFO] Epoch-3-Batch-400: Train: Loss-0.5644,Accuracy-0.8438
[INFO] Epoch-3-Batch-600: Train: Loss-0.4656,Accuracy-0.8750
[INFO] Epoch-3-Batch-800: Train: Loss-0.4800,Accuracy-0.8438
[INFO] Epoch-3: Train: Loss-0.4432,Accuracy-0.0035 |Test:Loss-0.3720, Accuracy-0.8914
[INFO] Epoch-4-Batch-200: Train: Loss-0.3568,Accuracy-0.8750
[INFO] Epoch-4-Batch-400: Train: Loss-0.3659,Accuracy-0.8594
[INFO] Epoch-4-Batch-600: Train: Loss-0.3843,Accuracy-0.8281
[INFO] Epoch-4-Batch-800: Train: Loss-0.3291,Accuracy-0.8906
[INFO] Epoch-4: Train: Loss-0.3600,Accuracy-0.0037 |Test:Loss-0.3328, Accuracy-0.9015
[INFO] Epoch-5-Batch-200: Train: Loss-0.2843,Accuracy-0.8906
[INFO] Epoch-5-Batch-400: Train: Loss-0.2729,Accuracy-0.9375
[INFO] Epoch-5-Batch-600: Train: Loss-0.2628,Accuracy-0.9219
[INFO] Epoch-5-Batch-800: Train: Loss-0.1479,Accuracy-0.9531
[INFO] Epoch-5: Train: Loss-0.3174,Accuracy-0.0039 |Test:Loss-0.2917, Accuracy-0.9161
[INFO] Epoch-6-Batch-200: Train: Loss-0.3273,Accuracy-0.9062
[INFO] Epoch-6-Batch-400: Train: Loss-0.2906,Accuracy-0.9375
[INFO] Epoch-6-Batch-600: Train: Loss-0.2957,Accuracy-0.9062
[INFO] Epoch-6-Batch-800: Train: Loss-0.2804,Accuracy-0.9375
[INFO] Epoch-6: Train: Loss-0.2839,Accuracy-0.0039 |Test:Loss-0.2652, Accuracy-0.9247
[INFO] Epoch-7-Batch-200: Train: Loss-0.3675,Accuracy-0.8906
[INFO] Epoch-7-Batch-400: Train: Loss-0.3041,Accuracy-0.8906
[INFO] Epoch-7-Batch-600: Train: Loss-0.2421,Accuracy-0.9375
[INFO] Epoch-7-Batch-800: Train: Loss-0.1761,Accuracy-0.9219
[INFO] Epoch-7: Train: Loss-0.2561,Accuracy-0.0039 |Test:Loss-0.2401, Accuracy-0.9319
[INFO] Epoch-8-Batch-200: Train: Loss-0.1390,Accuracy-0.9531
[INFO] Epoch-8-Batch-400: Train: Loss-0.1204,Accuracy-0.9688
[INFO] Epoch-8-Batch-600: Train: Loss-0.1118,Accuracy-0.9844
[INFO] Epoch-8-Batch-800: Train: Loss-0.1276,Accuracy-0.9844
[INFO] Epoch-8: Train: Loss-0.2306,Accuracy-0.0041 |Test:Loss-0.2178, Accuracy-0.9365
[INFO] Epoch-9-Batch-200: Train: Loss-0.4543,Accuracy-0.9062
[INFO] Epoch-9-Batch-400: Train: Loss-0.3267,Accuracy-0.9219
[INFO] Epoch-9-Batch-600: Train: Loss-0.1870,Accuracy-0.9531
[INFO] Epoch-9-Batch-800: Train: Loss-0.3354,Accuracy-0.9062
[INFO] Epoch-9: Train: Loss-0.2094,Accuracy-0.0039 |Test:Loss-0.2016, Accuracy-0.9412
[INFO] Epoch-10-Batch-200: Train: Loss-0.1400,Accuracy-0.9219
[INFO] Epoch-10-Batch-400: Train: Loss-0.2871,Accuracy-0.9219
[INFO] Epoch-10-Batch-600: Train: Loss-0.1343,Accuracy-0.9531
[INFO] Epoch-10-Batch-800: Train: Loss-0.2881,Accuracy-0.8906
[INFO] Epoch-10: Train: Loss-0.1906,Accuracy-0.0039 |Test:Loss-0.1805, Accuracy-0.9460
[INFO] Epoch-11-Batch-200: Train: Loss-0.2244,Accuracy-0.9688
[INFO] Epoch-11-Batch-400: Train: Loss-0.1173,Accuracy-0.9688
[INFO] Epoch-11-Batch-600: Train: Loss-0.1551,Accuracy-0.9531
[INFO] Epoch-11-Batch-800: Train: Loss-0.1560,Accuracy-0.9531
[INFO] Epoch-11: Train: Loss-0.1748,Accuracy-0.0041 |Test:Loss-0.1693, Accuracy-0.9504
[INFO] Epoch-12-Batch-200: Train: Loss-0.2438,Accuracy-0.9688
[INFO] Epoch-12-Batch-400: Train: Loss-0.0888,Accuracy-0.9688
[INFO] Epoch-12-Batch-600: Train: Loss-0.0938,Accuracy-0.9688
[INFO] Epoch-12-Batch-800: Train: Loss-0.1019,Accuracy-0.9688
[INFO] Epoch-12: Train: Loss-0.1611,Accuracy-0.0041 |Test:Loss-0.1562, Accuracy-0.9515
[INFO] Epoch-13-Batch-200: Train: Loss-0.2955,Accuracy-0.9219
[INFO] Epoch-13-Batch-400: Train: Loss-0.3402,Accuracy-0.9062
[INFO] Epoch-13-Batch-600: Train: Loss-0.1040,Accuracy-0.9688
[INFO] Epoch-13-Batch-800: Train: Loss-0.1147,Accuracy-0.9844
[INFO] Epoch-13: Train: Loss-0.1491,Accuracy-0.0040 |Test:Loss-0.1475, Accuracy-0.9562
[INFO] Epoch-14-Batch-200: Train: Loss-0.0578,Accuracy-1.0000
[INFO] Epoch-14-Batch-400: Train: Loss-0.0836,Accuracy-0.9688
[INFO] Epoch-14-Batch-600: Train: Loss-0.1362,Accuracy-0.9688
[INFO] Epoch-14-Batch-800: Train: Loss-0.0897,Accuracy-0.9531
[INFO] Epoch-14: Train: Loss-0.1387,Accuracy-0.0041 |Test:Loss-0.1441, Accuracy-0.9561
[INFO] Epoch-15-Batch-200: Train: Loss-0.1424,Accuracy-0.9844
[INFO] Epoch-15-Batch-400: Train: Loss-0.0657,Accuracy-0.9844
[INFO] Epoch-15-Batch-600: Train: Loss-0.0836,Accuracy-0.9688
[INFO] Epoch-15-Batch-800: Train: Loss-0.1404,Accuracy-0.9688
[INFO] Epoch-15: Train: Loss-0.1289,Accuracy-0.0042 |Test:Loss-0.1301, Accuracy-0.9608
[INFO] Epoch-16-Batch-200: Train: Loss-0.1637,Accuracy-0.9219
[INFO] Epoch-16-Batch-400: Train: Loss-0.0509,Accuracy-1.0000
[INFO] Epoch-16-Batch-600: Train: Loss-0.2507,Accuracy-0.9375
[INFO] Epoch-16-Batch-800: Train: Loss-0.0801,Accuracy-0.9688
[INFO] Epoch-16: Train: Loss-0.1205,Accuracy-0.0041 |Test:Loss-0.1252, Accuracy-0.9610
[INFO] Epoch-17-Batch-200: Train: Loss-0.0761,Accuracy-0.9688
[INFO] Epoch-17-Batch-400: Train: Loss-0.0439,Accuracy-1.0000
[INFO] Epoch-17-Batch-600: Train: Loss-0.2204,Accuracy-0.9062
[INFO] Epoch-17-Batch-800: Train: Loss-0.0640,Accuracy-0.9844
[INFO] Epoch-17: Train: Loss-0.1128,Accuracy-0.0041 |Test:Loss-0.1211, Accuracy-0.9617
[INFO] Epoch-18-Batch-200: Train: Loss-0.0907,Accuracy-0.9844
[INFO] Epoch-18-Batch-400: Train: Loss-0.0587,Accuracy-0.9844
[INFO] Epoch-18-Batch-600: Train: Loss-0.0478,Accuracy-1.0000
[INFO] Epoch-18-Batch-800: Train: Loss-0.0532,Accuracy-0.9844
[INFO] Epoch-18: Train: Loss-0.1057,Accuracy-0.0042 |Test:Loss-0.1113, Accuracy-0.9654
[INFO] Epoch-19-Batch-200: Train: Loss-0.1051,Accuracy-0.9531
[INFO] Epoch-19-Batch-400: Train: Loss-0.1953,Accuracy-0.9219
[INFO] Epoch-19-Batch-600: Train: Loss-0.1334,Accuracy-0.9531
[INFO] Epoch-19-Batch-800: Train: Loss-0.1170,Accuracy-0.9531
[INFO] Epoch-19: Train: Loss-0.0991,Accuracy-0.0040 |Test:Loss-0.1087, Accuracy-0.9662
[INFO] Epoch-20-Batch-200: Train: Loss-0.0581,Accuracy-1.0000
[INFO] Epoch-20-Batch-400: Train: Loss-0.0779,Accuracy-0.9531
[INFO] Epoch-20-Batch-600: Train: Loss-0.1448,Accuracy-0.9375
[INFO] Epoch-20-Batch-800: Train: Loss-0.0859,Accuracy-0.9688
[INFO] Epoch-20: Train: Loss-0.0934,Accuracy-0.0041 |Test:Loss-0.1075, Accuracy-0.9661
143.9970304965973
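The script records per-epoch curves in losses, acces, eval_losses, and eval_acces but never plots them. Here is a minimal plotting sketch (my addition; run it in the same session as the script above, and it assumes matplotlib is installed):

import matplotlib.pyplot as plt

epochs = range(1, nums_epoch + 1)
plt.plot(epochs, losses, label='train loss')
plt.plot(epochs, eval_losses, label='test loss')
plt.plot(epochs, acces, label='train acc')
plt.plot(epochs, eval_acces, label='test acc')
plt.xlabel('epoch')
plt.legend()
plt.show()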

4. CPU-to-GPU improvements

The overall workflow stays the same when porting from CPU to GPU. The key helper is:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
which means: use the GPU when the machine has CUDA available, and fall back to the CPU otherwise.
4.1 Notes

4.1.1 Using .to(device) flexibly makes the network run more smoothly, allows a larger batch size, and allocates resources more sensibly.
4.1.2 cuDNN is used here, so it needs to be imported: import torch.backends.cudnn as cudnn
4.1.3 Move the network, the data, the labels, and the loss function to the device with .to(device).
4.1.4 Allocate a correspondingly larger batch size.

4.2 Concretely, the changes are as follows (shown as screenshots in the original post, not reproduced here; a code sketch of all four changes follows):
4.2.1 Call .to(device) on the network and enable cuDNN
4.2.2 Allocate a larger batch size
4.2.3 Use torch.nn and move the loss function to the device
4.2.4 Call .to(device) on the training and test img and label tensors
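
Since the screenshots are not reproduced, below is a minimal sketch of just the changed lines relative to the CPU demo (these are fragments, reusing the Net, train_set, test_set, and DataLoader names from the CPU script; the last two lines live inside the training and test loops):

import torch
import torch.backends.cudnn as cudnn

# 4.2.1 pick the device, move the network onto it, and enable cuDNN autotuning
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = Net().to(device)
cudnn.benchmark = True

# 4.2.2 larger batches, since the GPU processes them in parallel
train_data = DataLoader(train_set, batch_size=640, shuffle=True)
test_data = DataLoader(test_set, batch_size=320, shuffle=True)

# 4.2.3 move the loss function to the device as well
criterion = torch.nn.CrossEntropyLoss().to(device)

# 4.2.4 inside both the training and test loops, move each batch to the device
img = img.to(device)
label = label.to(device)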

The GPU demo:

import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch import optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.datasets import mnist
from torchvision import transforms


class Net(nn.Module):
    def __init__(self, in_c=784, out_c=10):
        super(Net, self).__init__()

        # Fully connected layer
        self.fc1 = nn.Linear(in_c, 512)
        # Activation layer
        self.act1 = nn.ReLU(inplace=True)

        self.fc2 = nn.Linear(512, 256)
        self.act2 = nn.ReLU(inplace=True)
        self.fc3 = nn.Linear(256, 128)
        self.act3 = nn.ReLU(inplace=True)

        self.fc4 = nn.Linear(128, out_c)

    def forward(self, x):
        x = self.act1(self.fc1(x))
        x = self.act2(self.fc2(x))
        x = self.act3(self.fc3(x))
        x = self.fc4(x)
        return x

t1 = time.time()
# Pick the device and build the network
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = Net()
cudnn.benchmark = True
net = net.to(device)


# Training set (download=True fetches the data on the first run)
train_set = mnist.MNIST('./data', train=True, transform=transforms.ToTensor(), download=True)
# Test set
test_set = mnist.MNIST('./data', train=False, transform=transforms.ToTensor(), download=True)
# Training set loader
train_data = DataLoader(train_set, batch_size=640, shuffle=True)
# Test set loader
test_data = DataLoader(test_set, batch_size=320, shuffle=True)

# Define the loss function -- cross-entropy, moved to the device
criterion = torch.nn.CrossEntropyLoss().to(device)
# Define the optimizer -- stochastic gradient descent
optimizer = optim.SGD(net.parameters(), lr=0.01, weight_decay=0.00005)

# Start training
losses = []  # record training loss
acces = []   # record training accuracy

eval_losses = []  # record test loss
eval_acces = []   # record test accuracy
nums_epoch = 20   # number of epochs

for epoch in range(nums_epoch):
    train_loss = 0  # running training loss for this epoch
    train_acc = 0   # running training accuracy for this epoch
    net.train()
    for batch, (img, label) in enumerate(train_data):
        # Flatten each 1x28x28 image into a 784-dim vector
        img = img.reshape(img.size(0), -1)
        img = Variable(img)
        img = img.to(device)

        label = Variable(label)
        label = label.to(device)

        # Forward pass
        out = net(img)
        loss = criterion(out, label)
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Record the loss
        train_loss += loss.item()
        # Compute the classification accuracy
        _, pred = out.max(1)
        num_correct = (pred == label).sum().item()
        acc = num_correct / img.shape[0]
        # Accumulate accuracy for every batch, not only the printed ones
        train_acc += acc
        if (batch + 1) % 200 == 0:
            print(
                '[INFO] Epoch-{}-Batch-{}: Train: Loss-{:.4f},Accuracy-{:.4f}'.format(epoch + 1, batch + 1, loss.item(),
                                                                                      acc))

    losses.append(train_loss / len(train_data))
    acces.append(train_acc / len(train_data))
    eval_loss = 0
    eval_acc = 0
    # The test set is not trained on
    net.eval()
    with torch.no_grad():
        for img, label in test_data:
            img = img.reshape(img.size(0), -1)
            img = Variable(img)
            img = img.to(device)

            label = Variable(label)
            label = label.to(device)

            out = net(img)
            loss = criterion(out, label)
            eval_loss += loss.item()
            _, pred = out.max(1)
            num_correct = (pred == label).sum().item()
            acc = num_correct / img.shape[0]
            eval_acc += acc
    eval_losses.append(eval_loss / len(test_data))
    eval_acces.append(eval_acc / len(test_data))
    # Print the statistics
    set_epoch = epoch + 1
    set_lossTrain = train_loss / len(train_data)
    set_AccTrain = train_acc / len(train_data)
    set_lossEval = eval_loss / len(test_data)
    set_AccEval = eval_acc / len(test_data)

    print('[INFO] Epoch-{}: Train: Loss-{:.4f},Accuracy-{:.4f} |Test:Loss-{:.4f}, Accuracy-{:.4f}'.format(set_epoch,
                                                                                                          set_lossTrain,
                                                                                                          set_AccTrain,
                                                                                                          set_lossEval,
                                                                                                          set_AccEval))
t2 = time.time()
t = t2 - t1
print(t)  # total elapsed time in seconds

Output of the GPU demo:

[INFO] Epoch-1: Train: Loss-2.3025,Accuracy-0.0000 |Test:Loss-2.2982, Accuracy-0.1018
[INFO] Epoch-2: Train: Loss-2.2950,Accuracy-0.0000 |Test:Loss-2.2902, Accuracy-0.1034
[INFO] Epoch-3: Train: Loss-2.2867,Accuracy-0.0000 |Test:Loss-2.2812, Accuracy-0.1491
[INFO] Epoch-4: Train: Loss-2.2762,Accuracy-0.0000 |Test:Loss-2.2685, Accuracy-0.3172
[INFO] Epoch-5: Train: Loss-2.2611,Accuracy-0.0000 |Test:Loss-2.2498, Accuracy-0.4545
[INFO] Epoch-6: Train: Loss-2.2383,Accuracy-0.0000 |Test:Loss-2.2205, Accuracy-0.5208
[INFO] Epoch-7: Train: Loss-2.2020,Accuracy-0.0000 |Test:Loss-2.1736, Accuracy-0.5357
[INFO] Epoch-8: Train: Loss-2.1407,Accuracy-0.0000 |Test:Loss-2.0909, Accuracy-0.5215
[INFO] Epoch-9: Train: Loss-2.0353,Accuracy-0.0000 |Test:Loss-1.9529, Accuracy-0.5164
[INFO] Epoch-10: Train: Loss-1.8638,Accuracy-0.0000 |Test:Loss-1.7380, Accuracy-0.5524
[INFO] Epoch-11: Train: Loss-1.6231,Accuracy-0.0000 |Test:Loss-1.4687, Accuracy-0.6281
[INFO] Epoch-12: Train: Loss-1.3598,Accuracy-0.0000 |Test:Loss-1.2147, Accuracy-0.7030
[INFO] Epoch-13: Train: Loss-1.1373,Accuracy-0.0000 |Test:Loss-1.0222, Accuracy-0.7411
[INFO] Epoch-14: Train: Loss-0.9642,Accuracy-0.0000 |Test:Loss-0.8700, Accuracy-0.7768
[INFO] Epoch-15: Train: Loss-0.8361,Accuracy-0.0000 |Test:Loss-0.7663, Accuracy-0.7890
[INFO] Epoch-16: Train: Loss-0.7451,Accuracy-0.0000 |Test:Loss-0.6880, Accuracy-0.8087
[INFO] Epoch-17: Train: Loss-0.6791,Accuracy-0.0000 |Test:Loss-0.6312, Accuracy-0.8223
[INFO] Epoch-18: Train: Loss-0.6299,Accuracy-0.0000 |Test:Loss-0.5862, Accuracy-0.8302
[INFO] Epoch-19: Train: Loss-0.5907,Accuracy-0.0000 |Test:Loss-0.5545, Accuracy-0.8388
[INFO] Epoch-20: Train: Loss-0.5588,Accuracy-0.0000 |Test:Loss-0.5313, Accuracy-0.8452
84.11509966850281

Two remarks on the GPU log: with batch_size=640 there are only about 94 batches per epoch, so the (batch + 1) % 200 branch never fires and no per-batch lines are printed (the near-zero per-epoch training accuracies in both logs are an artifact of the original run, which accumulated train_acc only inside that branch), and convergence per epoch is slower because the 10x larger batch means roughly 10x fewer gradient updates at the same learning rate.

My hand-typed version is probably still not entirely correct; corrections are welcome. | Adapted from Alibaba Cloud Tianchi.

Origin juejin.im/post/7077831793012916261