Lesson 2: Learning AlexNet


I. Building AlexNet and Training It on a Flower Classification Dataset

We build a flower classifier with PyTorch. Unlike CIFAR10, this dataset does not come pre-split into training and test sets, so we have to split it ourselves. There are five flower classes: daisy, dandelion, roses, sunflower, and tulips. Download the flower classification dataset here: https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz
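The split can be done with a small script. Below is a minimal sketch; the directory names flower_photos and flower_data and the 9:1 train/val ratio are assumptions, so adjust them to your own layout:

import os
import random
from shutil import copy


def split_dataset(src_root="flower_photos", dst_root="flower_data", val_rate=0.1):
    random.seed(0)  # fixed seed so the split is reproducible
    # every sub-directory of src_root is treated as one class
    classes = [c for c in os.listdir(src_root)
               if os.path.isdir(os.path.join(src_root, c))]
    for cla in classes:
        images = os.listdir(os.path.join(src_root, cla))
        # randomly pick val_rate of the images for the validation set
        val_images = set(random.sample(images, k=int(len(images) * val_rate)))
        for img in images:
            subset = "val" if img in val_images else "train"
            dst_dir = os.path.join(dst_root, subset, cla)
            os.makedirs(dst_dir, exist_ok=True)
            copy(os.path.join(src_root, cla, img), dst_dir)


if __name__ == "__main__":
    split_dataset()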

II. Code

1. model.py ---- defines the AlexNet network structure (train.py imports it as model, so the file name must be model.py)

The code is as follows (example):

import torch.nn as nn
import torch


class AlexNet(nn.Module):
    def __init__(self, num_classes=1000, init_weights=False):
        super(AlexNet, self).__init__()
        # nn.Sequential() packs the layers into one module, which keeps the code concise
        # convolutional layers extract image features
        self.features = nn.Sequential(
            nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),  # input[3, 224, 224]  output[48, 55, 55]
            nn.ReLU(inplace=True),                                  # modify values in place to save memory
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[48, 27, 27]
            nn.Conv2d(48, 128, kernel_size=5, padding=2),           # output[128, 27, 27]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 13, 13]
            nn.Conv2d(128, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 192, kernel_size=3, padding=1),          # output[192, 13, 13]
            nn.ReLU(inplace=True),
            nn.Conv2d(192, 128, kernel_size=3, padding=1),          # output[128, 13, 13]
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),                  # output[128, 6, 6]
        )
        # nn.Sequential() packs the layers into one module, which keeps the code concise
        # fully connected layers classify the image
        self.classifier = nn.Sequential(
            nn.Dropout(p=0.5),                                      # Dropout randomly deactivates neurons; the default rate is 0.5
            nn.Linear(128 * 6 * 6, 2048),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 2048),
            nn.ReLU(inplace=True),
            nn.Linear(2048, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    # forward pass
    def forward(self, x):
        x = self.features(x)
        x = torch.flatten(x, start_dim=1)                          # flatten before feeding into the fully connected layers
        x = self.classifier(x)
        return x

    # weight initialization; in fact PyTorch initializes weights automatically when building the network
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):                                               # convolutional layer
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') # Kaiming (He) normal initialization
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)                                       # initialize bias to 0
            elif isinstance(m, nn.Linear):                                             # fully connected layer
                nn.init.normal_(m.weight, 0, 0.01)                                     # normal distribution initialization
                nn.init.constant_(m.bias, 0)                                           # initialize bias to 0
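
As a quick sanity check of the network definition (a minimal sketch, not part of the original code), you can run a dummy batch through the model and confirm that the output shape matches num_classes:

import torch
from model import AlexNet

net = AlexNet(num_classes=5, init_weights=True)
x = torch.randn(2, 3, 224, 224)  # dummy batch: 2 RGB images of size 224×224
print(net(x).shape)              # expected: torch.Size([2, 5])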

2. train.py ---- loads the dataset and trains the network; loss is computed on the training set, accuracy on the validation set, and the trained weights are saved

The code is as follows (example):

import os
import sys
import json
import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib.pyplot as plt
import numpy as np
import torch.optim as optim
from tqdm import tqdm
from model import AlexNet


def main():
    # use GPU for training if available
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # data preprocessing
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),  # random crop, then resize to 224×224
                                     transforms.RandomHorizontalFlip(),  # random horizontal flip with probability 0.5
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),  # cannot be just 224, must be (224, 224)
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    # import and load the training set

    # get the path of the image dataset
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # go up two directory levels
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path

    # load the training set and apply preprocessing
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)  # number of training images

    # dict mapping class to index: {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())  # swap the keys and values of flower_list
    # write cla_dict into a json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    # load the training set in batches of batch_size
    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,           # the training set loaded above
                                               batch_size=batch_size,   # number of samples per batch
                                               shuffle=True,            # shuffle the training set
                                               num_workers=nw)          # number of worker processes; set to 0 on Windows

    # load the validation set and apply preprocessing
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)  # number of validation images

    # load the validation set in batches
    validate_loader = torch.utils.data.DataLoader(validate_dataset,  # the validation set loaded above
                                                  batch_size=4, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))
    # optional visualization of a validation batch (commented out; the two lines
    # at the end must stay commented too, since they reference names defined here)
    # test_data_iter = iter(validate_loader)
    # test_image, test_label = next(test_data_iter)
    #
    # def imshow(img):
    #     img = img / 2 + 0.5  # unnormalize
    #     npimg = img.numpy()
    #     plt.imshow(np.transpose(npimg, (1, 2, 0)))
    #     plt.show()
    #
    # print(' '.join('%5s' % cla_dict[test_label[j].item()] for j in range(4)))
    # imshow(utils.make_grid(test_image))

    net = AlexNet(num_classes=5, init_weights=True)  # instantiate the network (5 output classes, initialize weights)

    net.to(device)                                   # move the network to the chosen device (GPU/CPU)
    loss_function = nn.CrossEntropyLoss()            # cross-entropy loss
    # pata = list(net.parameters())
    optimizer = optim.Adam(net.parameters(), lr=0.0002)  # optimizer (parameters to train, learning rate)

    epochs = 10
    save_path = './AlexNet.pth'
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()                              # enable Dropout during training
        running_loss = 0.0                       # running_loss is reset at the start of each epoch
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, data in enumerate(train_bar):  # iterate over the training set; step starts from 0
            images, labels = data                # get the images and labels of the batch
            optimizer.zero_grad()                # clear historical gradients
            outputs = net(images.to(device))     # forward pass
            loss = loss_function(outputs, labels.to(device))  # compute the loss
            loss.backward()                      # backward pass
            optimizer.step()                     # update the parameters

            # print training progress (visualize the training process)
            running_loss += loss.item()

            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        # validate
        net.eval()   # disable Dropout during validation
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():  # no gradients are computed during validation, which saves memory
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]  # take the index (label) of the max value in outputs as the prediction
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()
                
        # save the network weights of the epoch with the highest accuracy
        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

3. predict.py ---- after obtaining the trained weights, test classification with an image of your own

The code is as follows (example):

import torch
from model import AlexNet
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
import json

# preprocessing
data_transform = transforms.Compose(
    [transforms.Resize((224, 224)),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# load image
img = Image.open("tulip.png").convert('RGB')  # convert to RGB in case the PNG has an alpha channel
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)

# read class_indict
try:
    with open('./class_indices.json', 'r') as json_file:
        class_indict = json.load(json_file)
except Exception as e:
    print(e)
    exit(-1)

# create model
model = AlexNet(num_classes=5)
# load model weights
model_weight_path = "./AlexNet.pth"
model.load_state_dict(torch.load(model_weight_path))

# disable Dropout
model.eval()
with torch.no_grad():
    # predict class
    output = torch.squeeze(model(img))     # squeeze the output, i.e. remove the batch dimension
    predict = torch.softmax(output, dim=0)
    predict_cla = torch.argmax(predict).numpy()
print(class_indict[str(predict_cla)], predict[predict_cla].item())
plt.show()
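
To inspect the model's confidence across all five classes rather than just the top prediction (a small hedged extension, assuming the predict and class_indict variables from the script above):

# print the softmax probability of every class
for i in range(len(predict)):
    print("class: {:10}   prob: {:.3f}".format(class_indict[str(i)],
                                               predict[i].item()))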


Reposted from blog.csdn.net/qq_45825952/article/details/124018251