卷积神经网络与PyTorch实践(二)

上一篇文章中,我们介绍了卷积神经网络的基本理论,并使用PyTorch搭建一个类似LeNet-5的网络结构,用于Fashion-MNIST数据集的图像分类。不过上一篇文章中搭建卷积网络使用的是普通卷积,本篇文章我们继续用PyTorch搭建一个类似LeNet-5的网络结构,不过我们这次使用空洞卷积核,搭建的是一个空洞卷积神经网络。

在PyTorch库中使用nn.Conv2d()函数,通过调节参数dilation的取值,进行不同大小卷积核的空洞卷积运算。针对搭建的空洞卷积神经网络,使用如下图所示的网络结构。

 上图中搭建的卷积神经网络含有两个空洞卷积层、两个池化层、两个全连接层,以及一个包含10个神经元的分类层。该网络结构除了卷积方式的差异之外,与上一节搭建的卷积网络结构完全相同。

下面搭建一个空洞卷积神经网络:

class MyConvdilaNet(nn.Module):
    """LeNet-5-style image classifier built from dilated (atrous) convolutions.

    The shape comments below assume a (N, 1, 28, 28) input; the first linear
    layer's 32*4*4 input size only matches inputs of that spatial size.
    """

    def __init__(self):
        # Zero-argument super() binds to the class itself. The two-argument
        # form would look up the module-level name, which is rebound to a
        # model instance right after the class definition.
        super().__init__()
        # First dilated convolution stage.
        self.conv1 = nn.Sequential(
            # A 3x3 kernel with dilation=2 spans a 5x5 window; with padding=1:
            # (1, 28, 28) -> (16, 26, 26)
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3,
                      stride=1, padding=1, dilation=2),
            nn.ReLU(),
            # Average pooling: (16, 26, 26) -> (16, 13, 13)
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        # Second dilated convolution stage.
        self.conv2 = nn.Sequential(
            # No padding this time: (16, 13, 13) -> (32, 9, 9)
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3,
                      stride=1, padding=0, dilation=2),
            nn.ReLU(),
            # Average pooling (not max pooling): (32, 9, 9) -> (32, 4, 4)
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        # Two hidden fully connected layers followed by a 10-way output layer.
        self.classify = nn.Sequential(
            nn.Linear(32 * 4 * 4, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        """Run the forward pass and return one row of 10 class scores per sample."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten the feature maps to (N, 32*4*4) for the linear layers.
        x = x.view(x.size(0), -1)
        return self.classify(x)


# NOTE: this rebinds the name ``MyConvdilaNet`` from the class to a model
# instance; the rest of the script refers to the model by this name.
MyConvdilaNet = MyConvdilaNet()

同样通过nn.Sequential()、nn.Conv2d()、nn.ReLU()、nn.AvgPool2d()、nn.Linear()等层定义了一个拥有两个空洞卷积层和三个全连接层(其中最后一层是包含10个神经元的分类层)的卷积神经网络,其中空洞卷积通过在nn.Conv2d()中设置参数dilation=2来实现,最后在forward()函数中定义了数据在网络中的前向传播过程。

接下来使用上一节已经定义好的训练函数train_model()对网络进行训练,并使用折线图对训练过程中的精度和损失函数进行可视化:

# Cross-entropy objective optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(MyConvdilaNet.parameters(), lr=0.0003)
MyConvdilaNet, train_process = train_model(
    model=MyConvdilaNet,
    traindataloader=train_loader,
    train_rate=0.8,
    crition=criterion,
    optimizer=optimizer,
    num_epochs=25,
)

# Draw the loss curves (left panel) and accuracy curves (right panel).
# Each entry: subplot index, train column, val column, train legend,
# val legend, y-axis label.
curve_panels = [
    (1, "train_loss_all", "val_loss_all", "Train loss", "Val loss", "Loss"),
    (2, "train_acc_all", "val_acc_all", "Train acc", "Val acc", "acc"),
]
plt.figure(figsize=(12, 4))
for panel_idx, train_col, val_col, train_name, val_name, y_name in curve_panels:
    plt.subplot(1, 2, panel_idx)
    plt.plot(train_process.epoch, train_process[train_col], "ro-", label=train_name)
    plt.plot(train_process.epoch, train_process[val_col], "bs-", label=val_name)
    plt.legend()
    plt.xlabel("epoch")
    plt.ylabel(y_name)
plt.show()

 从模型训练过程中可以看出,损失函数在训练集上迅速减小,在验证集上先减小然后逐渐收敛到一个很小的区间,说明模型已经稳定。在训练集上的精度在一直增大,而在验证集上的精度收敛到一个小区间内。对测试样本的预测结果,我们使用混淆矩阵可视化每个类别的预测情况:

# Predict the whole test set in one forward pass.
MyConvdilaNet.eval()
with torch.no_grad():  # inference only, so skip building the autograd graph
    output = MyConvdilaNet(test_data_x)
pre_lab = torch.argmax(output, 1)
acc = accuracy_score(test_data_y, pre_lab)
print("在测试集上的预测精度为:",acc)
# confusion_matrix(y_true, y_pred): rows are the true labels and columns are
# the predicted labels.
conf_mat = confusion_matrix(test_data_y, pre_lab)
df_cm = pd.DataFrame(conf_mat, index=class_label, columns=class_label)
heatmap = sns.heatmap(df_cm, annot=True, fmt="d", cmap="YlGnBu")
heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha="right")
heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha="right")
# The heatmap draws DataFrame columns along x and rows along y, so x carries
# the predictions and y the ground truth.
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.show()

 通过混淆矩阵,我们发现模型最容易把T-shirt与Shirt辨认混淆。

 该测试案例的完整代码如下:

import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score,confusion_matrix,classification_report
import matplotlib.pyplot as plt
import seaborn as sns
import copy
import time
import torch
import torch.nn as nn
from torch.optim import Adam
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
# Training split: images are converted to float tensors by ToTensor().
# download=False means the dataset files must already exist under the root.
train_data = FashionMNIST(
    "./Dataset/FashionMNIST",
    train=True,
    transform=transforms.ToTensor(),
    download=False,
)
# Mini-batch loader over the training split (reshuffled on every pass).
train_loader = Data.DataLoader(train_data, batch_size=64, shuffle=True, num_workers=0)

# Test split: loaded without a transform and converted by hand below.
test_data = FashionMNIST("./Dataset/FashionMNIST", train=False, download=False)
# Scale the raw pixel values by 1/255 and insert a channel axis at dim 1.
test_data_x = (test_data.data.float() / 255.0).unsqueeze(1)
test_data_y = test_data.targets
class_label = train_data.classes
class MyConvdilaNet(nn.Module):
    """LeNet-5-style image classifier built from dilated (atrous) convolutions.

    The shape comments below assume a (N, 1, 28, 28) input; the first linear
    layer's 32*4*4 input size only matches inputs of that spatial size.
    """

    def __init__(self):
        # Zero-argument super() binds to the class itself. The two-argument
        # form would look up the module-level name, which is rebound to a
        # model instance right after the class definition.
        super().__init__()
        # First dilated convolution stage.
        self.conv1 = nn.Sequential(
            # A 3x3 kernel with dilation=2 spans a 5x5 window; with padding=1:
            # (1, 28, 28) -> (16, 26, 26)
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3,
                      stride=1, padding=1, dilation=2),
            nn.ReLU(),
            # Average pooling: (16, 26, 26) -> (16, 13, 13)
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        # Second dilated convolution stage.
        self.conv2 = nn.Sequential(
            # No padding this time: (16, 13, 13) -> (32, 9, 9)
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3,
                      stride=1, padding=0, dilation=2),
            nn.ReLU(),
            # Average pooling (not max pooling): (32, 9, 9) -> (32, 4, 4)
            nn.AvgPool2d(kernel_size=2, stride=2),
        )
        # Two hidden fully connected layers followed by a 10-way output layer.
        self.classify = nn.Sequential(
            nn.Linear(32 * 4 * 4, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        """Run the forward pass and return one row of 10 class scores per sample."""
        x = self.conv1(x)
        x = self.conv2(x)
        # Flatten the feature maps to (N, 32*4*4) for the linear layers.
        x = x.view(x.size(0), -1)
        return self.classify(x)


# NOTE: this rebinds the name ``MyConvdilaNet`` from the class to a model
# instance; the rest of the script refers to the model by this name.
MyConvdilaNet = MyConvdilaNet()
# Training routine.
def train_model(model, traindataloader, train_rate, crition, optimizer, num_epochs=25):
    """Train ``model`` and record per-epoch training and validation metrics.

    In every epoch the first ``round(len(traindataloader) * train_rate)``
    batches yielded by ``traindataloader`` are used for optimisation; the
    remaining batches are only evaluated. After the last epoch the weights
    from the epoch with the highest validation accuracy are loaded back into
    ``model``.

    Note: the split is by batch position within an epoch. If the loader
    reshuffles its data on every pass, the validation batches are not a
    fixed held-out set.

    Args:
        model: network to train; updated in place.
        traindataloader: sized iterable yielding ``(inputs, targets)`` batches.
        train_rate: fraction of the batches used for training.
        crition: loss callable invoked as ``crition(output, targets)``.
        optimizer: optimizer already bound to ``model``'s parameters.
        num_epochs: number of passes over ``traindataloader``.

    Returns:
        ``(model, train_process)`` where ``train_process`` is a DataFrame with
        columns ``epoch``, ``train_loss_all``, ``val_loss_all``,
        ``train_acc_all`` and ``val_acc_all`` (all metrics as plain floats).

    Raises:
        ValueError: if the split leaves no training batch or no validation
            batch.
    """
    batch_num = len(traindataloader)
    train_batch_num = round(batch_num * train_rate)
    # Fail early with a clear message instead of dividing by zero later.
    if not 0 < train_batch_num < batch_num:
        raise ValueError(
            "train_rate=%r assigns %d of %d batches to training; at least one "
            "training batch and one validation batch are required"
            % (train_rate, train_batch_num, batch_num)
        )

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    train_loss_all, train_acc_all = [], []
    val_loss_all, val_acc_all = [], []

    for _ in range(num_epochs):
        train_loss, train_corrects, train_num = 0.0, 0, 0
        val_loss, val_corrects, val_num = 0.0, 0, 0
        for step, (b_x, b_y) in enumerate(traindataloader):
            batch_size = b_x.size(0)
            if step < train_batch_num:
                # Training batch: forward, backward, parameter update.
                model.train()
                output = model(b_x)
                loss = crition(output, b_y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pre_lab = torch.argmax(output, 1)
                # Weight the batch loss by batch size so the epoch figure is
                # a per-sample average.
                train_loss += loss.item() * batch_size
                train_corrects += (pre_lab == b_y).sum().item()
                train_num += batch_size
            else:
                # Validation batch: evaluation only, no gradients needed.
                model.eval()
                with torch.no_grad():
                    output = model(b_x)
                    loss = crition(output, b_y)
                pre_lab = torch.argmax(output, 1)
                val_loss += loss.item() * batch_size
                val_corrects += (pre_lab == b_y).sum().item()
                val_num += batch_size

        train_loss_all.append(train_loss / train_num)
        train_acc_all.append(train_corrects / train_num)
        val_loss_all.append(val_loss / val_num)
        # Stored as a plain float, consistent with the training accuracy.
        val_acc_all.append(val_corrects / val_num)

        # Snapshot the weights whenever validation accuracy improves.
        if val_acc_all[-1] > best_acc:
            best_acc = val_acc_all[-1]
            best_model_wts = copy.deepcopy(model.state_dict())

    # Restore the best-performing weights before handing the model back.
    model.load_state_dict(best_model_wts)
    train_process = pd.DataFrame(
        data={
            "epoch": range(num_epochs),
            "train_loss_all": train_loss_all,
            "val_loss_all": val_loss_all,
            "train_acc_all": train_acc_all,
            "val_acc_all": val_acc_all,
        }
    )
    return model, train_process
# Cross-entropy objective optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(MyConvdilaNet.parameters(), lr=0.0003)
MyConvdilaNet, train_process = train_model(
    model=MyConvdilaNet,
    traindataloader=train_loader,
    train_rate=0.8,
    crition=criterion,
    optimizer=optimizer,
    num_epochs=25,
)

# Draw the loss curves (left panel) and accuracy curves (right panel).
# Each entry: subplot index, train column, val column, train legend,
# val legend, y-axis label.
curve_panels = [
    (1, "train_loss_all", "val_loss_all", "Train loss", "Val loss", "Loss"),
    (2, "train_acc_all", "val_acc_all", "Train acc", "Val acc", "acc"),
]
plt.figure(figsize=(12, 4))
for panel_idx, train_col, val_col, train_name, val_name, y_name in curve_panels:
    plt.subplot(1, 2, panel_idx)
    plt.plot(train_process.epoch, train_process[train_col], "ro-", label=train_name)
    plt.plot(train_process.epoch, train_process[val_col], "bs-", label=val_name)
    plt.legend()
    plt.xlabel("epoch")
    plt.ylabel(y_name)
plt.show()

# Predict the whole test set in one forward pass.
MyConvdilaNet.eval()
with torch.no_grad():  # inference only, so skip building the autograd graph
    output = MyConvdilaNet(test_data_x)
pre_lab = torch.argmax(output, 1)
acc = accuracy_score(test_data_y, pre_lab)
print("在测试集上的预测精度为:",acc)
# confusion_matrix(y_true, y_pred): rows are the true labels and columns are
# the predicted labels.
conf_mat = confusion_matrix(test_data_y, pre_lab)
df_cm = pd.DataFrame(conf_mat, index=class_label, columns=class_label)
heatmap = sns.heatmap(df_cm, annot=True, fmt="d", cmap="YlGnBu")
heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha="right")
heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha="right")
# The heatmap draws DataFrame columns along x and rows along y, so x carries
# the predictions and y the ground truth.
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.show()


猜你喜欢

转载自blog.csdn.net/qq_42681787/article/details/129672335