Pytorch之MNIST数据集的训练和测试

训练和测试的完整代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Dataset
import argparse
import os


# 训练
# Training: run one epoch over train_loader, print per-batch stats,
# and checkpoint the model every `args.checkpoint_interval` epochs.
def train(args, model, device, train_loader, optimizer):
    """Train `model` for one epoch and save a periodic checkpoint.

    Args:
        args: parsed CLI namespace (uses checkpoint_interval).
        model: network to train (already moved to `device`).
        device: torch.device to run on.
        train_loader: DataLoader yielding (images, labels) batches.
        optimizer: optimizer updating `model`'s parameters.

    NOTE(review): relies on the module-level `epoch` variable set by the
    caller's training loop — kept for backward compatibility with the
    main script, but consider passing `epoch` explicitly.
    """
    model.train()
    num_correct = 0
    for batch_index, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)

        # forward
        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)

        # backward
        optimizer.zero_grad()   # clear gradients accumulated from the previous step
        loss.backward()         # backpropagate
        optimizer.step()        # update parameters

        _, predicted = torch.max(outputs, dim=1)

        # number of correct predictions in this batch
        batch_correct = (predicted == labels).sum().item()
        # BUGFIX: divide by the actual batch size — the last batch may be
        # smaller than args.batch_size, which under-reported its accuracy.
        batch_accuracy = batch_correct / labels.size(0)

        # running count of correct predictions for the whole epoch
        num_correct += batch_correct

        print(f'Epoch:{epoch},Batch ID:{batch_index}/{len(train_loader)}, loss:{loss}, Batch accuracy:{batch_accuracy*100}%')

    # accuracy over the full training set for this epoch
    epoch_accuracy = num_correct / len(train_loader.dataset)
    print(f'Epoch Accuracy:{epoch_accuracy}')

    # periodically save model weights (plain f-string; the original mixed
    # an f-string prefix with %-formatting)
    if epoch % args.checkpoint_interval == 0:
        torch.save(model.state_dict(), f"checkpoints/VGG16_MNIST_{epoch}.pth")



# 验证
def test(args, model, device, test_loader):
    model.eval()
    total_loss = 0
    num_correct = 0
    if args.pretrained_weights.endswith(".pth"):
        model.load_state_dict(torch.load(args.pretrained_weights))

    # 不计算梯度,节省计算资源
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            output = model(images)

            # 总的loss
            total_loss += F.cross_entropy(output, labels).item()     # item()用于取出tensor里边的值

            # torch.max():返回的是两个值,第一个值是具体的value,第二个值是value所在的index
            _, predicted = torch.max(output, dim=1)

            # 预测对的总个数
            num_correct += (predicted == labels).sum().item()

    # 平均loss
    test_loss = total_loss / len(test_loader.dataset)
    # 平均准确率
    accuracy = num_correct / len(test_loader.dataset)
    
	# print sth.
    print(f'Average loss:{test_loss}\nTest Accuracy:{accuracy*100}%')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Pytorch-MNIST_classification')
    parser.add_argument('--epochs', type=int, default=20, help='number of epochs')
    parser.add_argument('--batch_size', type=int, default=32, help='size of each image batch')
    parser.add_argument('--num_classes', type=int, default=10, help='number of classes')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
    parser.add_argument('--pretrained_weights', type=str, default='checkpoints/', help='if specified starts from checkpoint model')
    parser.add_argument("--img_size", type=int, default=224, help="size of each image dimension")
    parser.add_argument("--checkpoint_interval", type=int, default=1, help="interval between saving model weights")
    # BUGFIX: the original used default=False with no type, so `--train False`
    # parsed as the truthy string "False", and `args.train == True` compared a
    # str to a bool (always False). A store_true flag is unambiguous: pass
    # `--train` to train, omit it to test.
    parser.add_argument("--train", action="store_true", help="train instead of test")

    args = parser.parse_args()
    print(args)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # create output directories (recursive; no error if they already exist)
    os.makedirs("output", exist_ok=True)
    os.makedirs("checkpoints", exist_ok=True)

    # BUGFIX: RandomResizedCrop is a training-time augmentation; applying it
    # at evaluation time makes test results nondeterministic. The test set
    # gets a deterministic Resize instead.
    train_transform = transforms.Compose([transforms.ToTensor(),
                                          transforms.RandomResizedCrop(args.img_size)])
    test_transform = transforms.Compose([transforms.ToTensor(),
                                         transforms.Resize((args.img_size, args.img_size))])

    # download / load the training data
    train_data = datasets.MNIST(root='data',
                                train=True,
                                transform=train_transform,
                                target_transform=None,
                                download=True)
    # download / load the test data
    test_data = datasets.MNIST(root='data',
                               train=False,
                               transform=test_transform,
                               target_transform=None,
                               download=True)

    # wrap the datasets in loaders (shuffle only during training)
    train_loader = DataLoader(dataset=train_data,
                              batch_size=args.batch_size,
                              shuffle=True)
    test_loader = DataLoader(dataset=test_data,
                             batch_size=args.batch_size)

    # ImageNet-pretrained VGG16, adapted to MNIST:
    model = models.vgg16(pretrained=True)
    # replace the classifier head to output `num_classes` logits
    model.classifier[6] = nn.Linear(in_features=4096, out_features=args.num_classes, bias=True)
    # MNIST images are grayscale (1 channel), so replace the first conv layer
    model.features[0] = nn.Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    print(model)
    model = model.to(device)

    # optimizer (others, e.g. Adam, would also work)
    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    # BUGFIX: load checkpoint weights ONCE, before the loop. The original
    # reloaded the checkpoint at the top of every epoch, discarding the
    # weights learned during each preceding epoch.
    if args.pretrained_weights.endswith(".pth"):
        model.load_state_dict(torch.load(args.pretrained_weights))

    if args.train:
        for epoch in range(1, args.epochs + 1):
            train(args, model, device, train_loader, optimizer)
    else:
        test(args, model, device, test_loader)

测试结果:
(测试结果截图略)        我只训练了不到10轮,效果不是太好,还有提升空间。

说明:
        MNIST数据集可以通过torchvision中的datasets.MNIST下载,也可以自己下载(注意存放路径);我模型使用的是torchvision中的models中预训练好的vgg16网络,也可以自己搭建网络。

猜你喜欢

转载自blog.csdn.net/Roaddd/article/details/112184020