pytorch implementation of Chinese MNIST dataset (Kaggle)

1. Chinese MNIST dataset

          This data set comes from the Kaggle website  Chinese MNIST | Kaggle

It mainly includes 15,000 64*64 handwritten Chinese digital pictures, and a content file.

 2. Neural network structure

         Three-layer fully connected network: 4096*300*80*15

3. Propagation process

          The calculation process of the BP algorithm can refer to the previous article, which is described in detail and will not be repeated.

4. The focus of this project: loading of data sets

           Here we mainly use the method of obtaining the label through the file name. For the specific implementation process, please refer to the video tutorial at station B [absolute dry goods] pytorch loads its own data set, data set loading-video collection

5. Program (pytorch)

# 1 加载必要的库
import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
import torchvision
import os
from PIL import Image
from torch.utils.data import  DataLoader,Dataset
import matplotlib.pyplot as plt
from sklearn import preprocessing

# 2 定义超参数
batch_size = 128 #训练每批处理的数据
num_epochs = 10  #训练数据集的轮次

# 3 下载、加载数据
path_dir = "F:\\JetBrains\\PycharmProjects\\pytorchLearning\\Chinese_Digit_Recognition\\data\\data"
# 通过继承Dataset类来进行数据加载
class MyDataset(Dataset): # 继承Dataset
    def __init__(self, path_dir, transform=None):  # 初始化一些属性,获取数据集所在路径的数据列表
        self.path_dir = path_dir  # 文件路径
        self.transform = transform  # 对象进行数据处理
        self.images = os.listdir(self.path_dir)  # 把路径下的所有文件放在一个列表里;即在self.images这个张量中存储path_dir路径的所有文件的名称和后缀名

    def __len__(self): # 返回整个数据集的大小
        return len(self.images)

    def __getitem__(self, index):  # 根据索引index返回图像及标签,索引是根据文件夹内的文件顺序进行排列,从0开始递增
        image_index = self.images[index]  # 根据索引获取图像文件名称
        img_path = os.path.join(self.path_dir, image_index)  # 获取index在确定数值下图片的路径或者目录
        img = Image.open(img_path)  # 读取图像

        # 根据目录名称获取图像标签
        label = img_path.split('\\')[-1].split('.')[0].split('_')[-1]  # 绝对路径后加\\, '\\'的后一位, '.'的前一位就是标签,如cat.0.jpg, 标签就是cat
        #化为int型,并-1与图片数字对应
        label = int(label)
        label = label - 1

        # if self.transform is not None:
        img = self.transform(img)
        return img, label

#加载数据集
train_set = MyDataset(path_dir, transform=torchvision.transforms.ToTensor())
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)


# 4 构建网络模型
class MnistNet(nn.Module):
    def __init__(self):
        super(MnistNet, self).__init__()
        self.fc1 = nn.Linear(1 * 64 * 64, 300)
        self.fc2 = nn.Linear(300, 80)
        self.fc3 = nn.Linear(80, 15)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.fc1(x)  # [batch_size,300]
        x = F.relu(x)  # [batch_size,300]
        x = self.fc2(x)  # [batch_size,80]
        x = F.relu(x)  # [batch_size,80]
        x = self.fc3(x)  # [batch_size,15]
        # return x
        return F.log_softmax(x, dim=-1)

# 5 定义优化器
mnist_net = MnistNet()
optimizer = optim.Adam(mnist_net.parameters(), lr=0.001)
train_loss_list = []
# train_count_list = []

# 6 定义训练方法
def train(epoch):
    mode = True
    mnist_net.train(mode=mode)
    correct, total= 0, 0
    for idx, (data, target) in enumerate(train_loader):
        #将target从Tuple型转换为Tensor型     注:如果先将label转化成了int型,在这里将不需要此转换
        # le = preprocessing.LabelEncoder()
        # target = le.fit_transform(target)
        # target = torch.as_tensor(target)

        optimizer.zero_grad()
        output = mnist_net(data)
        loss = F.nll_loss(output, target)  # 对数似然损失
        # loss = F.cross_entropy(output, target)  # 交叉熵损失
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(output.data, 1)  # 选择最大的(概率)值所在的列数就是他所对应的类别数,
        total += target.size(0)
        correct += (predicted == target).sum().item()
        acc = correct / total
        if idx % 117 == 0 and idx !=0:
        # if idx % 117 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAccuracy: {:.2f} %'
                .format(epoch+1, (idx)* batch_size, len(train_set), 100. * (idx) / len(train_loader), loss.item(), 100 * acc))

        train_loss_list.append(loss.item())
        # train_count_list.append(idx * batch_size + (epoch - 1) * len(train_loader))
'''
# 7 定义测试方法
def test():
    test_loss = 0
    correct = 0
    mnist_net.eval()
    #test_dataloader = get_dataloader(train=False)
    with torch.no_grad():
        for data, target in test_loader:
            output = mnist_net(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            pred = output.data.max(1, keepdim=True)[1]  # 获取最大值的位置,[batch_size,1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(test_loader),
        100. * correct / len(test_loader)))
'''
# 8 调用方法6、7
for epoch in range(num_epochs):  # 模型训练迭代次数
        train(epoch)
        # test()

# Save the model checkpoint
torch.save(mnist_net.state_dict(), 'my_handwrite_recognize_model.ckpt')

# 绘制函数
plt.plot(train_count_list)
plt.plot(train_loss_list)
plt.title('Training loss Curve')
plt.ylabel('Loss')
plt.xlabel('epochs')
plt.show()

#可视化验证训练效果
test_dataset = MyDataset(path_dir, transform=torchvision.transforms.ToTensor())
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True)
# 随机获取部分训练数据
dataiter = iter(test_dataloader)
data, target = dataiter.next()
output = mnist_net(data)
_, predicted = torch.max(output.data, 1)
# 打印标签、预测
print('  label:', target)
print('predict:', predicted)
import numpy as np
# 定义一个显示图像的函数
def imshow(img):
    img = img / 2 + 0.5
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
# # 显示图像
imshow(torchvision.utils.make_grid(data))

Output result:

Train Epoch: 10 [14976/15000 (99%)]	Loss: 0.313809	Accuracy: 96.28 %
  label: tensor([ 5, 11, 10,  0,  5,  9, 12,  1,  4, 11,  7,  8, 12, 14, 10, 12])
predict: tensor([ 5, 11, 10,  0,  5,  9, 12,  1,  4, 11,  7,  8, 12, 14, 10, 12])

Loss curve:

 Note: The code comes from network collection and self-modification, if intruded, it can be deleted.

Guess you like

Origin blog.csdn.net/cxzgood/article/details/121319379