Pokemon Dataset Classification in PyTorch Based on ResNet18 Transfer Learning

1. Implementation process

1. Dataset description

The dataset is divided into 5 categories, with image counts as follows:

  • Pikachu: 234
  • Mewtwo: 239
  • Squirtle: 223
  • Charmander: 238
  • Bulbasaur: 234

Download link (Baidu Netdisk): https://pan.baidu.com/s/1bsppVXDRsweVKAxSoLy4sw
Extraction code: 9fqo
The image files come in 4 formats (jpg, jpeg, png, and gif) and the images vary in size, so the training, validation, and test images all need to be resized, among other preprocessing operations. In this paper, images are resized to 224×224.

2. Data preprocessing

This paper preprocesses the data with a custom Pokemon class built on PyTorch's Dataset framework (it relies on os, glob, csv, random, PIL.Image, torch, and torchvision.transforms), converting the image dataset into an {images, labels} mapping. The __init__ method is as follows:

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize

        self.name2label = {}    # e.g. "squirtle": 0
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root,name)):
                continue
            self.name2label[name] = len(self.name2label.keys())
        # print(self.name2label)

        # image,label
        self.images, self.labels = self.load_csv('images.csv')

        # Split the dataset into training, validation, and test sets
        if mode == 'train': # 60%
            self.images = self.images[0:int(0.6*len(self.images))]
            self.labels = self.labels[0:int(0.6*len(self.labels))]
        elif mode == 'val': # 20% = 60% -> 80%
            self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
        else:               # 20% = 80% -> 100%
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]

Here, root is the root directory where the dataset is stored; resize is the uniform size to which the images are resized on output; mode selects which split to read (train, val, or test); and name2label builds a dictionary mapping image category names to integer labels, making it easy to look up each category's label. The load_csv method creates the {images, labels} mapping, where images holds the file paths of the images; the code is as follows:

    def load_csv(self, filename):
        if not os.path.exists(os.path.join(self.root, filename)):
            # The csv file does not exist yet, so create it
            images = []
            for name in self.name2label.keys():
                # pokemon\\mewtwo\\00001.png
                images += glob.glob(os.path.join(self.root,name,'*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
                images += glob.glob(os.path.join(self.root, name, '*.gif'))
            # 1168, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images),images)
            # Save the (image, label) pairs into a csv file
            random.shuffle(images)
            with open(os.path.join(self.root, filename),mode='w',newline='') as f:
                writer = csv.writer(f)
                for img in images:  # 'pokemon\\bulbasaur\\00000000.png'
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    # 'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img, label])
                # print('written into csv file:', filename)
        # Load the saved csv file
        images, labels = [],[]
        with open(os.path.join(self.root,filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)
                images.append(img)
                labels.append(label)
        assert len(images) == len(labels)
        return images, labels

The code for getting the dataset size and for fetching an element by index is:

    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        # idx ranges over [0, len(self.images))
        # self.images, self.labels
        # img:'G:/datasets/pokemon\\charmander\\00000182.png'
        # label: 0,1,2,3,4
        img, label = self.images[idx], self.labels[idx]
        transform = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # string path => image data
            transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),
            transforms.RandomRotation(15),      # random rotation within ±15 degrees
            transforms.CenterCrop(self.resize), # center-crop back to resize
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.485,0.456,0.406],
            #                      std=[0.229,0.224,0.225])
            transforms.Normalize(mean=[0.6096, 0.7286, 0.5103],
                                 std=[1.5543, 1.4887, 1.5958])
        ])

        img = transform(img)
        label = torch.tensor(label)
        return img, label

For the mean and std passed to transforms.Normalize, either compute them over your own training images, or directly use the ImageNet empirical values mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
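
For reference, the following is a minimal sketch of one way to estimate these statistics; it assumes the dataset root is pokemon and that the Normalize step is temporarily removed from the transform in __getitem__, so the loader yields raw [0, 1] tensors:

    import torch
    from torch.utils.data import DataLoader

    # Assumes the Pokemon dataset defined above, with Normalize removed
    # from its transform so x contains raw [0,1] pixel values.
    db = Pokemon('pokemon', 224, 'train')
    loader = DataLoader(db, batch_size=32)

    n = 0
    mean = torch.zeros(3)
    sq_mean = torch.zeros(3)
    for x, _ in loader:                 # x: [b,3,H,W]
        b = x.size(0)
        x = x.view(b, 3, -1)
        mean += x.mean(dim=2).sum(dim=0)        # sum of per-image channel means
        sq_mean += (x ** 2).mean(dim=2).sum(dim=0)
        n += b
    mean /= n
    std = (sq_mean / n - mean ** 2).sqrt()      # std = sqrt(E[x^2] - E[x]^2)
    print('mean:', mean, 'std:', std)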
A batch of batch_size=32 images displayed with the Visdom visualization tool is shown in the following figure:
[Figure: Visdom window showing one batch of 32 preprocessed Pokemon images]
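
As a rough sketch of how such a batch can be rendered (assuming a Visdom server started with python -m visdom.server, and the pokemon dataset root used above):

    import torch
    import visdom
    from torch.utils.data import DataLoader

    viz = visdom.Visdom()
    db = Pokemon('pokemon', 224, 'train')
    loader = DataLoader(db, batch_size=32, shuffle=True)

    x, y = next(iter(loader))
    # Invert the Normalize from __getitem__ so the colors display correctly:
    # (x - mean) / std  =>  x * std + mean
    mean = torch.tensor([0.6096, 0.7286, 0.5103]).view(3, 1, 1)
    std = torch.tensor([1.5543, 1.4887, 1.5958]).view(3, 1, 1)
    viz.images(x * std + mean, nrow=8, win='batch', opts=dict(title='batch of 32'))
    viz.text(str(y.numpy()), win='labels', opts=dict(title='labels'))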

2. Design the model

This paper adopts the idea of transfer learning: it directly uses the pretrained resnet18 classifier, retains its first 17 layers of network structure, and replaces the final fully connected layer with a new 5-way output. The code is as follows:

    from torchvision.models import resnet18

    trained_model = resnet18(pretrained=True)
    model = nn.Sequential(*list(trained_model.children())[:-1],  # [b,512,1,1]
                          Flatten(),   # [b,512,1,1] => [b,512]
                          nn.Linear(512, 5)   # new 5-way classification head
                          ).to(device)
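
Note: pretrained=True matches the torchvision API at the time of writing; since torchvision 0.13 it is deprecated in favor of the weights argument, e.g. resnet18(weights=ResNet18_Weights.IMAGENET1K_V1).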

Here, Flatten() is a helper module that flattens the feature map into a vector; the code is as follows:

    class Flatten(nn.Module):
        def __init__(self):
            super(Flatten, self).__init__()

        def forward(self, x):
            # Flatten all dimensions after the batch dimension.
            shape = torch.prod(torch.tensor(x.shape[1:])).item()
            return x.view(-1, shape)
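
Note: recent PyTorch versions ship an equivalent built-in module, nn.Flatten(), which flattens everything after the batch dimension by default and could replace this custom class in the nn.Sequential.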

3. Construct the loss function and optimizer

The loss function is cross-entropy and the optimizer is Adam, with the learning rate set to 0.001. Since nn.CrossEntropyLoss applies log-softmax internally, the model's final nn.Linear layer outputs raw logits. The code is as follows:

    import torch.optim as optim

    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

4. Train, Validate, and Test

    def main():
        best_acc, best_epoch = 0, 0
        global_step = 0
        viz.line([0], [-1], win='loss', opts=dict(title='loss'))
        viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
        for epoch in range(epochs):
            model.train()   # ensure BatchNorm layers are in training mode
            for step, (x, y) in enumerate(train_loader):
                # x: [b,3,224,224]  y: [b]
                x, y = x.to(device), y.to(device)
                output = model(x)
                loss = criterion(output, y)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                viz.line([loss.item()], [global_step], win='loss', update='append')
                global_step += 1
            # validation
            if epoch % 1 == 0:  # validate every epoch
                val_acc = evaluate(model, val_loader)
                if val_acc > best_acc:
                    best_acc = val_acc
                    best_epoch = epoch
                    torch.save(model.state_dict(), 'best.mdl')
                    viz.line([val_acc], [global_step], win='val_acc', update='append')

        print('best acc:', best_acc, 'best epoch:', best_epoch+1)
        # load the best model checkpoint
        model.load_state_dict(torch.load('best.mdl'))
        print('loaded from ckpt!')
        test_acc = evaluate(model, test_loader)
        print('test acc:', test_acc)

    def evaluate(model, loader):
        model.eval()    # switch BatchNorm layers to evaluation mode
        correct = 0
        total = len(loader.dataset)
        for (x, y) in loader:
            x, y = x.to(device), y.to(device)
            with torch.no_grad():
                output = model(x)
                pred = output.argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
        return correct / total
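
The training code above assumes some surrounding setup (device, epochs, viz, and the three data loaders) that the post does not show. A minimal sketch of that setup, with an assumed value for the epoch count, might look like this:

    import torch
    import visdom
    from torch.utils.data import DataLoader

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    epochs = 10              # assumption: the post does not state the epoch count
    batch_size = 32
    viz = visdom.Visdom()    # requires a running `python -m visdom.server`

    train_db = Pokemon('pokemon', 224, mode='train')
    val_db = Pokemon('pokemon', 224, mode='val')
    test_db = Pokemon('pokemon', 224, mode='test')
    train_loader = DataLoader(train_db, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_db, batch_size=batch_size)
    test_loader = DataLoader(test_db, batch_size=batch_size)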

5. Test results

The loss curve on the training set and the accuracy curve on the validation set are shown in the following figure:
[Figure: Visdom plots of the training loss and validation accuracy]
The console output is:

    best acc: 0.9358974358974359 best epoch: 3
    loaded from ckpt!
    test acc: 0.9401709401709402

This shows that the validation accuracy peaks at epoch 3, so the model saved at that point is taken as the best model; evaluated on the test set, it reaches an accuracy of 94.02%.

2. References

[1] https://www.bilibili.com/video/BV1f34y1k7fi?p=106
[2] https://blog.csdn.net/Weary_PJ/article/details/122765199
