PyTorch快速上手篇一 本地加载MNIST数据集进行单机训练

一、前言

写这篇博客的目的,是想让一些AI小白能够快速上手Pytorch AI框架,对于大佬们是不适用的哦!本文主要是基于以下几个方面展开的:

  • PyTorch的一些简单介绍
  • 单机训练(CPU/GPU)
  • MNIST数据集load方式

二、PyTorch简介

这段不是重点,等有时间再补,哈哈哈~~

三、单机训练(cpu/gpu)

1、 首先可以定义了一些外置输入参数,方便大家调参,如下:

	parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batchsize', '-b', type=int, default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batchsize', '-tb', type=int, default=1000,
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', '-e', type=int, default=10,
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='Number of GPU in each mini-batch')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')
    parser.add_argument('--save-model', '-sm', action='store_true', default=False,
                        help='For Saving the current Model')

这里主要设置了--gpu参数,可以读取大家给定的gpu的个数,然后就是大家熟知的batchsize,epoch,learning rate等参数。当然pytorch也有接口读取环境中的gpu个数,如下:

import torch
gpu_num = torch.cuda.device_count()

2、 通过重写好的load mnist的类,来本地加载mnist数据集:

train_data = LocalDataset(root + 'train.txt')
test_data = LocalDataset(root + 'test.txt')

3、 根据设置的超参数,初始化网络模型以及优化器

model = Net()
model = model.to(device)

optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

4、 根据设定的epoch值,开始训练以及测试

for epoch in range(1, args.epochs + 1):
     train(args, model, device, train_loader, optimizer, epoch)
     test(args, model, device, test_loader)

5、 这里我采用的是简单的LeNet模型,模型如下:

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)

四、MNIST数据集load方式

在这里我修改了下MNIST数据集的加载方式,直接从本地load读取,然后进行训练。PyTorch官网给的demo,是直接下载MNIST训练的,如下:

root = './dataset'
train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(root, train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),

设置download=True,会直接将MNIST数据集下载放在root文件夹下。这里我重写了MNIST类从本地加载数据集,代码片段:

class LocalDataset(Dataset):
    def __init__(self, base_path):
        self.data = []
        with open(base_path) as fp:
            for line in fp.readlines():
                tmp = line.split(" ")
                self.data.append([tmp[0], tmp[1][7:8]])

        self.transformations = \
            transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))
                                ])

    def __getitem__(self, index):
        img = self.transformations(Image.open(self.data[index][0]))
        label = int(self.data[index][1])
        return img, label

    def __len__(self):
        return len(self.data)

需要注意这里先下载解压好MNIST数据集,没有的童鞋可以点击这里下载。解压好的数据集格式是二进制,然后需要通过以下DataPreprocess类进行处理,得到train.txt以及test.txt,这两个txt文件有两列,存有图片的具体路径,以及图片label,经tranforms转换的值。

class DataPreprocess(object):
    def __init__(self, root):
        self.root = root

    @property
    def get_train_set(self):
        train_set = (
            mnist.read_image_file(os.path.join(self.root, 'train-images-idx3-ubyte')),
            mnist.read_label_file(os.path.join(self.root, 'train-labels-idx1-ubyte')))
        return train_set

    @property
    def get_test_set(self):
        test_set = (
            mnist.read_image_file(os.path.join(self.root, 't10k-images-idx3-ubyte')),
            mnist.read_label_file(os.path.join(self.root, 't10k-labels-idx1-ubyte')))
        return test_set

    def convert_to_img(self):
        f = open(self.root + 'train.txt', 'w')
        data_path = self.root + 'train/'
        if not os.path.exists(data_path):
            os.makedirs(data_path)
        for i, (img, label) in enumerate(zip(self.get_train_set[0], self.get_train_set[1])):
            img_path = data_path + str(i) + '.jpg'
            io.imsave(img_path, img.numpy())
            f.write(img_path + ' ' + str(label) + '\n')
        f.close()

        f = open(self.root + 'test.txt', 'w')
        data_path = self.root + 'test/'
        if not os.path.exists(data_path):
            os.makedirs(data_path)
        for i, (img, label) in enumerate(zip(self.get_test_set[0], self.get_test_set[1])):
            img_path = data_path + str(i) + '.jpg'
            io.imsave(img_path, img.numpy())
            f.write(img_path + ' ' + str(label) + '\n')
        f.close()

上述代码就不展开了,看不明白的童鞋可以留言评论。需查看完整代码的童鞋,可以移步到我的github

Guess you like

Origin blog.csdn.net/Zhaopanp_Crise/article/details/100023686