一、前言
写这篇博客的目的,是想让一些AI小白能够快速上手Pytorch AI框架,对于大佬们是不适用的哦!本文主要是基于以下几个方面展开的:
- PyTorch的一些简单介绍
- 单机训练(CPU/GPU)
- MNIST数据集load方式
二、PyTorch简介
这段不是重点,等有时间再补,哈哈哈~~
三、单机训练(cpu/gpu)
1、 首先可以定义了一些外置输入参数,方便大家调参,如下:
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batchsize', '-b', type=int, default=64,
help='input batch size for training (default: 64)')
parser.add_argument('--test-batchsize', '-tb', type=int, default=1000,
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', '-e', type=int, default=10,
help='number of epochs to train (default: 10)')
parser.add_argument('--gpu', '-g', type=int, default=0,
help='Number of GPU in each mini-batch')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', '-sm', action='store_true', default=False,
help='For Saving the current Model')
这里主要设置了--gpu
参数,可以读取大家给定的gpu的个数,然后就是大家熟知的batchsize,epoch,learning rate等参数。当然pytorch也有接口读取环境中的gpu个数,如下:
import torch
gpu_num = torch.cuda.device_count()
2、 通过重写好的load mnist的类,来本地加载mnist数据集:
train_data = LocalDataset(root + 'train.txt')
test_data = LocalDataset(root + 'test.txt')
3、 根据设置的超参数,初始化网络模型以及优化器
model = Net()
model = model.to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
4、 根据设定的epoch值,开始训练以及测试
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
5、 这里我采用的是简单的LeNet模型,模型如下:
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4 * 4 * 50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
四、MNIST数据集load方式
在这里我修改了下MNIST数据集的加载方式,直接从本地load读取,然后进行训练。PyTorch官网给的demo,是直接下载MNIST训练的,如下:
root = './dataset'
train_loader = torch.utils.data.DataLoader(
datasets.MNIST(root, train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
设置download=True
,会直接将MNIST数据集下载放在root
文件夹下。这里我重写了MNIST类从本地加载数据集,代码片段:
class LocalDataset(Dataset):
def __init__(self, base_path):
self.data = []
with open(base_path) as fp:
for line in fp.readlines():
tmp = line.split(" ")
self.data.append([tmp[0], tmp[1][7:8]])
self.transformations = \
transforms.Compose([transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
def __getitem__(self, index):
img = self.transformations(Image.open(self.data[index][0]))
label = int(self.data[index][1])
return img, label
def __len__(self):
return len(self.data)
需要注意这里先下载解压好MNIST数据集,没有的童鞋可以点击这里下载。解压好的数据集格式是二进制,然后需要通过以下DataPreprocess
类进行处理,得到train.txt
以及test.txt
,这两个txt文件有两列,存有图片的具体路径,以及图片label,经tranforms
转换的值。
class DataPreprocess(object):
def __init__(self, root):
self.root = root
@property
def get_train_set(self):
train_set = (
mnist.read_image_file(os.path.join(self.root, 'train-images-idx3-ubyte')),
mnist.read_label_file(os.path.join(self.root, 'train-labels-idx1-ubyte')))
return train_set
@property
def get_test_set(self):
test_set = (
mnist.read_image_file(os.path.join(self.root, 't10k-images-idx3-ubyte')),
mnist.read_label_file(os.path.join(self.root, 't10k-labels-idx1-ubyte')))
return test_set
def convert_to_img(self):
f = open(self.root + 'train.txt', 'w')
data_path = self.root + 'train/'
if not os.path.exists(data_path):
os.makedirs(data_path)
for i, (img, label) in enumerate(zip(self.get_train_set[0], self.get_train_set[1])):
img_path = data_path + str(i) + '.jpg'
io.imsave(img_path, img.numpy())
f.write(img_path + ' ' + str(label) + '\n')
f.close()
f = open(self.root + 'test.txt', 'w')
data_path = self.root + 'test/'
if not os.path.exists(data_path):
os.makedirs(data_path)
for i, (img, label) in enumerate(zip(self.get_test_set[0], self.get_test_set[1])):
img_path = data_path + str(i) + '.jpg'
io.imsave(img_path, img.numpy())
f.write(img_path + ' ' + str(label) + '\n')
f.close()
上述代码就不展开了,看不明白的童鞋可以留言评论。需查看完整代码的童鞋,可以移步到我的github。