PyTorch之CIFAR10数据集的训练和测试

代码:

import torch.nn as nn
import torch
import torch.nn.functional as F
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import argparse
import os

# 训练
def train(args, model, device, train_loader, optimizer):
    """Train ``model`` for ``args.epochs`` epochs, checkpointing periodically.

    Args:
        args: Namespace-like object; reads ``args.epochs`` (number of epochs
            to run) and ``args.checkpoint_interval`` (save a checkpoint every
            this many epochs).
        model: network to train; switched to training mode at each epoch.
        device: ``torch.device`` every batch is moved to.
        train_loader: iterable of ``(images, labels)`` batches that supports
            ``len()`` (used only for the progress message).
        optimizer: optimizer over ``model``'s parameters.

    Checkpoints are written to ``checkpoints/cifar10_<epoch>.pth``; the
    ``checkpoints`` directory must already exist.
    """
    for epoch in range(1, args.epochs + 1):
        model.train()
        num_batches = len(train_loader)
        for batch_index, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device)

            # Forward pass.
            output = model(images)
            loss = F.cross_entropy(output, labels)

            # Backward pass: clear stale gradients, back-propagate to compute
            # new ones, then let the optimizer update the parameters.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() prints the plain float instead of the tensor repr
            # (which would include device / grad_fn noise).
            print(f'Epoch:{epoch},Batch ID:{batch_index}/{num_batches}, loss:{loss.item()}')

        # Save a checkpoint every `checkpoint_interval` epochs.
        if epoch % args.checkpoint_interval == 0:
            torch.save(model.state_dict(), f'checkpoints/cifar10_{epoch}.pth')

def test(args, model, device, test_loader):
    model.eval()
    total_loss = 0
    num_correect = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)

            # 总的loss
            total_loss += F.cross_entropy(outputs, labels).item()

            # 预测值
            _, predected = torch.max(outputs, dim=1)

            # 预测对的总个数
            num_correect += (predected==labels).sum().item()

    # 计算平均loss
    average_loss = total_loss / len(test_loader.dataset)

    # 计算准确率
    accuracy = num_correect / len(test_loader.dataset)

    # 打印平均loss和准确率
    print(f'Average loss:{average_loss}\nTest Accuracy:{accuracy*100}%')

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description = 'Pytorch-cifar10_classification')
    parser.add_argument('--epochs', type=int, default=10, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='size of each image batch')
    parser.add_argument('--num_classes', type=int, default=10, help='number of classes')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
    parser.add_argument('--pretrained_weights', type=str, default='checkpoints/cifar10_17.pth',help='if specified starts from checkpoint model')
    parser.add_argument("--img_size", type=int, default=224, help="size of each image dimension")
    parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")
    # argparse has no boolean type: without a converter `--train False` would
    # yield the non-empty (hence truthy) string "False". Parse it explicitly.
    parser.add_argument("--train",
                        default=True,
                        type=lambda s: s.strip().lower() in ('true', '1', 'yes', 'y'),
                        help="train or test")
    args = parser.parse_args()
    print(args)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # os.makedirs() creates directories recursively; exist_ok avoids errors on reruns.
    os.makedirs("output", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    # Transforms: training uses random-resized-crop augmentation, while
    # evaluation must be deterministic, so the test split is only resized.
    # NOTE(review): ImageNet-pretrained VGG16 is usually fed inputs normalized
    # with ImageNet mean/std (transforms.Normalize); not added here so that
    # existing checkpoints keep seeing the same input distribution.
    train_transform = transforms.Compose([transforms.ToTensor(),
                                          transforms.RandomResizedCrop(args.img_size)
                                         ])
    test_transform = transforms.Compose([transforms.ToTensor(),
                                         transforms.Resize((args.img_size, args.img_size))
                                        ])

    # CIFAR-10 training / test splits. With download=False the data must
    # already be under ./data; pass download=True to fetch it.
    train_data = datasets.CIFAR10(root = 'data',
                                  train = True,
                                  download = False,
                                  transform = train_transform,
                                  target_transform = None,
                                  )
    test_data = datasets.CIFAR10(root = "data",
                                 train = False,
                                 download = False,
                                 transform = test_transform,
                                 target_transform = None)

    # Data loaders (only the training set is shuffled).
    train_loader = DataLoader(dataset = train_data,
                              batch_size = args.batch_size,
                              shuffle = True)
    test_loader = DataLoader(dataset = test_data,
                             batch_size = args.batch_size)

    # VGG16 initialised with pretrained weights.
    # NOTE(review): torchvision >= 0.13 deprecates `pretrained=True` in favour
    # of `weights=models.VGG16_Weights.IMAGENET1K_V1`; kept for older versions.
    model = models.vgg16(pretrained = True)
    # Optional: freeze everything and fine-tune only some fully connected
    # layers. `requires_grad` must be set on the *parameters*; assigning it on
    # the Module object itself has no effect.
    # for para in model.parameters():
    #     para.requires_grad = False
    # for layer in (model.classifier[3], model.classifier[6]):
    #     for para in layer.parameters():
    #         para.requires_grad = True
    # Replace the final classifier layer so it outputs num_classes logits.
    model.classifier[6] = nn.Linear(in_features=4096, out_features=args.num_classes, bias=True)
    model = model.to(device)
    # Print the network architecture.
    print(model)

    # Optimizer (other optimizers work too).
    optimizer = torch.optim.SGD(model.parameters(), lr = args.lr, momentum = args.momentum)
    # optimizer = torch.optim.Adam(model.parameters())

    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    if args.train:
        # Resume only from a checkpoint that actually exists: the default path
        # (epoch 17) is never produced by a fresh run with the default 10 epochs.
        if args.pretrained_weights.endswith(".pth"):
            if os.path.isfile(args.pretrained_weights):
                model.load_state_dict(torch.load(args.pretrained_weights, map_location=device))
            else:
                print(f'Checkpoint {args.pretrained_weights} not found; training from the pretrained VGG16 weights.')
        # train() itself already loops over args.epochs epochs.
        train(args, model, device, train_loader, optimizer)
    else:
        if args.pretrained_weights.endswith(".pth"):
            model.load_state_dict(torch.load(args.pretrained_weights, map_location=device))
        test(args, model, device, test_loader)

说明:
        cifar10数据集可以通过torchvision中的datasets.CIFAR10下载(将download参数设为True),也可以自己下载(注意存放路径需与root参数一致);我的模型使用的是torchvision.models中预训练好的vgg16网络,也可以自己搭建网络。

猜你喜欢

转载自blog.csdn.net/Roaddd/article/details/112139134