1. Implementation process
1. Data set description
The dataset contains 5 categories (1168 images in total), with the following number of images per category:
- Pikachu: 234
- Mewtwo: 239
- Squirtle: 223
- Charmander: 238
- Bulbasaur: 234
Download link (Baidu Netdisk): https://pan.baidu.com/s/1bsppVXDRsweVKAxSoLy4sw
Extraction code: 9fqo
The images come in four file formats (jpg, jpeg, png and gif) and in different sizes. All images (training, validation and test) therefore have to be resized and otherwise preprocessed. In this article, every image is brought to a uniform size of 224×224.
2. Data preprocessing
This article uses PyTorch's Dataset class (torch.utils.data.Dataset) to preprocess the dataset. It turns the image folders into {images, labels} pairs. The constructor of the custom Pokemon dataset class is:
```python
import csv
import glob
import os
import random

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class Pokemon(Dataset):

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        # category (sub-folder) name -> integer label, assigned in sorted name order
        self.name2label = {}
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label.keys())
        # print(self.name2label)

        # image paths and labels
        self.images, self.labels = self.load_csv('images.csv')

        # Split the dataset into training, validation and test sets
        if mode == 'train':  # 60%
            self.images = self.images[0:int(0.6 * len(self.images))]
            self.labels = self.labels[0:int(0.6 * len(self.labels))]
        elif mode == 'val':  # 20% = 60% -> 80%
            self.images = self.images[int(0.6 * len(self.images)):int(0.8 * len(self.images))]
            self.labels = self.labels[int(0.6 * len(self.labels)):int(0.8 * len(self.labels))]
        else:  # 20% = 80% -> 100%
            self.images = self.images[int(0.8 * len(self.images)):]
            self.labels = self.labels[int(0.8 * len(self.labels)):]
```
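The slicing above gives a deterministic 60/20/20 split of the shuffled CSV rows. The hypothetical helper `split_bounds` below reproduces the same `int(...)` cut points so that the split sizes can be checked. With the 1168 images listed in the dataset description, it yields 700 training, 234 validation and 234 test images.

```python
def split_bounds(n):
    """Return (train, val, test) index ranges with the same int() cut points
    as Pokemon.__init__: [0, 60%), [60%, 80%), [80%, 100%)."""
    cut1 = int(0.6 * n)
    cut2 = int(0.8 * n)
    return range(0, cut1), range(cut1, cut2), range(cut2, n)


# Per-category counts from the dataset description
counts = [234, 239, 223, 238, 234]
n = sum(counts)
train, val, test = split_bounds(n)
print(n, len(train), len(val), len(test))  # 1168 700 234 234
```

Because each range starts where the previous one stops, the three splits never overlap, as long as they are all cut from the same saved order.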
The parameters and attributes are as follows:
- root is the root directory where the dataset is stored.
- resize is the side length that every output image is brought to.
- mode selects which split to read: train, val or test. Any value other than train or val reads the test split.
- name2label is a dictionary that maps each category (sub-folder) name to an integer label, which makes it easy to look up an image's label.

The load_csv method builds the {images, labels} mapping, where each image is stored as its file path. The shuffled order is saved to images.csv and reused on later runs, so the three splits stay disjoint and identical across runs. The code is as follows:
```python
    # Method of the Pokemon class
    def load_csv(self, filename):
        if not os.path.exists(os.path.join(self.root, filename)):
            # The CSV file does not exist yet, so create it
            images = []
            for name in self.name2label.keys():
                # e.g. pokemon\\mewtwo\\00001.png
                images += glob.glob(os.path.join(self.root, name, '*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
                images += glob.glob(os.path.join(self.root, name, '*.gif'))
            # 1168, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images), images)

            # Shuffle once, then save as a CSV file of (image path, label) rows
            random.shuffle(images)
            with open(os.path.join(self.root, filename), mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:  # 'pokemon\\bulbasaur\\00000000.png'
                    # The parent folder name is the category name
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    # 'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img, label])
                # print('written into csv file:', filename)

        # Load the saved CSV file
        images, labels = [], []
        with open(os.path.join(self.root, filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)
                images.append(img)
                labels.append(label)

        assert len(images) == len(labels)
        return images, labels
```
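For comparison, here is a minimal standard-library sketch of the same {image path, label} indexing idea. It is not the article's code. The function name `build_index` and the `seed` parameter are hypothetical. It also swaps in `pathlib.Path.parent.name` for `img.split(os.sep)[-2]`, which avoids depending on the path separator. Like `load_csv`, it shuffles once, writes the CSV only if it is missing, and always reads the labels back from the file.

```python
import csv
import random
from pathlib import Path

EXTENSIONS = ('*.png', '*.jpg', '*.jpeg', '*.gif')


def build_index(root, filename='images.csv', seed=None):
    """Return (paths, labels) for images stored as root/<class_name>/<file>."""
    root = Path(root)
    # Labels follow the sorted order of the class sub-folders
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    name2label = {name: i for i, name in enumerate(class_names)}

    csv_path = root / filename
    if not csv_path.exists():
        paths = [str(p) for name in class_names
                 for pattern in EXTENSIONS
                 for p in sorted((root / name).glob(pattern))]
        random.Random(seed).shuffle(paths)  # shuffle once, then persist the order
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f)
            for path in paths:
                # The parent folder name is the class name
                writer.writerow([path, name2label[Path(path).parent.name]])

    # Always read back from the CSV so every split sees the same order
    paths, labels = [], []
    with open(csv_path, newline='') as f:
        for path, label in csv.reader(f):
            paths.append(path)
            labels.append(int(label))
    return paths, labels
```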
The __len__ method returns the size of the dataset, and __getitem__ loads and transforms the sample at a given index:
```python
    # Methods of the Pokemon class
    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # idx: [0, len(self.images) - 1]
        # img: 'G:/datasets/pokemon\\charmander\\00000182.png'
        # label: 0, 1, 2, 3, 4
        img, label = self.images[idx], self.labels[idx]

        # Note: as written, RandomRotation is applied in every mode, including val and test
        transform = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # string path => image data
            transforms.Resize((int(self.resize * 1.25), int(self.resize * 1.25))),
            transforms.RandomRotation(15),           # random rotation
            transforms.CenterCrop(self.resize),      # center crop
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.485, 0.456, 0.406],
            #                      std=[0.229, 0.224, 0.225])
            # Note: a per-channel std of values in [0, 1] cannot exceed 0.5,
            # so the std values below look suspect
            transforms.Normalize(mean=[0.6096, 0.7286, 0.5103],
                                 std=[1.5543, 1.4887, 1.5958])
        ])

        img = transform(img)
        label = torch.tensor(label)
        return img, label
```
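`__getitem__` first resizes each image to 1.25 × `resize` (280 × 280 for `resize=224`). It then rotates the image by up to 15 degrees and center-crops it back to 224 × 224. The standard-library sketch below uses hypothetical helper names. It computes the crop box and checks that the crop stays clear of the fill-colored corners a rotation introduces. It assumes torchvision's defaults for `RandomRotation` (rotation about the image center, `expand=False`) and a square image. The crop offset is assumed here to be half the size difference, rounded.

```python
import math


def center_crop_box(size, crop):
    """(left, top, right, bottom) of a crop x crop box centered in a size x size image."""
    offset = int(round((size - crop) / 2.0))
    return offset, offset, offset + crop, offset + crop


def max_clean_square(size, degrees):
    """Side of the largest centered, axis-aligned square that still lies inside
    a size x size square rotated by `degrees` about its center (|degrees| <= 45)."""
    theta = math.radians(abs(degrees))
    return size / (math.cos(theta) + math.sin(theta))


resize = 224
big = int(resize * 1.25)                     # 280
print(center_crop_box(big, resize))          # (28, 28, 252, 252)
print(round(max_clean_square(big, 15), 1))   # 228.6, which is >= 224
```

With these numbers, the largest fill-free square (about 228.6 px) is slightly bigger than the 224 px crop. The augmentation therefore rotates the content without leaving fill-colored corners in the output.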
The mean and std passed to transforms.Normalize can be computed from the dataset itself. Alternatively, you can directly use the widely used ImageNet values mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] (the commented-out line in the code).
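For reference, `transforms.Normalize` maps each channel $c$ of a tensor produced by `ToTensor()` (values in $[0, 1]$) as below. The per-channel statistics are taken over all $N$ pixels of all training images. Popoviciu's inequality gives a useful sanity check: values confined to $[0, 1]$ have a standard deviation of at most 0.5. Std values above 0.5, such as those in the code above, therefore cannot be population standard deviations of $[0, 1]$ pixel values.

```latex
\begin{aligned}
\mu_c &= \frac{1}{N}\sum_{i=1}^{N} x_{c,i}, \qquad
\sigma_c = \sqrt{\frac{1}{N}\sum_{i=1}^{N} \left(x_{c,i}-\mu_c\right)^2}, \\
\hat{x}_{c,i} &= \frac{x_{c,i}-\mu_c}{\sigma_c}, \\
x_{c,i} \in [0,1] \;&\Rightarrow\; \sigma_c^2 \le \frac{(1-0)^2}{4}
\;\Rightarrow\; \sigma_c \le 0.5 .
\end{aligned}
```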
A batch of 32 images (batch_size=32) displayed with the Visdom visualization tool is shown in the figure below:
2. Design model
This article uses transfer learning. It takes a ResNet-18 pretrained on ImageNet and keeps all of its layers except the last one, that is, the first 17 layers of the network. The final fully connected layer is replaced with a new one that outputs the 5 categories. The code is as follows:
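```python
import torch
import torch.nn as nn
from torchvision.models import resnet18

# Assumption: the article does not show how device is chosen
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ImageNet-pretrained ResNet-18 (torchvision >= 0.13 prefers
# resnet18(weights=ResNet18_Weights.DEFAULT) over pretrained=True)
trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],  # drop fc: [b,512,1,1]
                      Flatten(),          # custom module shown below: [b,512,1,1] => [b,512]
                      nn.Linear(512, 5)   # new classifier for the 5 categories
                      ).to(device)
```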
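The "18" in ResNet-18 counts the layers with learnable weights:
- one initial 7×7 convolution,
- four stages of two basic blocks, each with two 3×3 convolutions,
- the final fully connected layer.

The 1×1 downsampling convolutions in the shortcuts are conventionally not counted. In torchvision's ResNet, `children()[:-1]` keeps conv1, bn1, relu, maxpool, layer1 to layer4 and avgpool, and drops only `fc`. This is how the first 17 of these layers are kept:

```latex
\underbrace{1}_{\text{conv1}} + \underbrace{4 \times 2 \times 2}_{\text{layer1 to layer4}} + \underbrace{1}_{\text{fc}} = 18,
\qquad 18 - 1 = 17 \text{ layers kept.}
```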
Here, Flatten() is a small custom module that flattens each sample's feature map into a vector. Its code is as follows:
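```python
import torch
import torch.nn as nn


class Flatten(nn.Module):
    """Flatten [b, c, h, w] feature maps into [b, c*h*w] vectors."""

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # Number of features per sample, e.g. 512*1*1 = 512
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)
```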
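`x.view(-1, shape)` keeps the batch dimension and lays out each sample's remaining dimensions in row-major order, so a [b, 512, 1, 1] tensor becomes [b, 512]. The pure-Python sketch below uses a hypothetical helper `flatten_per_sample`, with nested lists standing in for tensors, to show the same ordering. Recent PyTorch versions also provide the built-in `nn.Flatten()`, which flattens from dimension 1 by default and could replace the custom module.

```python
def flatten_nested(value):
    """Yield the scalars of a nested list in row-major order."""
    if isinstance(value, list):
        for item in value:
            yield from flatten_nested(item)
    else:
        yield value


def flatten_per_sample(batch):
    """[b, d1, d2, ...] nested lists -> [b, d1*d2*...], like x.view(b, -1)."""
    return [list(flatten_nested(sample)) for sample in batch]


# A batch of b=2 samples, each shaped [c=2, h=1, w=2]
batch = [
    [[[1, 2]], [[3, 4]]],
    [[[5, 6]], [[7, 8]]],
]
print(flatten_per_sample(batch))  # [[1, 2, 3, 4], [5, 6, 7, 8]]
```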
3. Construct the loss function and optimizer
Cross-entropy is used as the loss function and Adam as the optimizer, with a learning rate of 0.001. The code is as follows:
```python
import torch.nn as nn
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
```
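`nn.CrossEntropyLoss` applies log-softmax to the raw logits and, with its default `reduction='mean'`, averages the negative log-probability of the correct class over the batch. This is why the model ends in a plain `nn.Linear(512, 5)` with no softmax. Below is a pure-Python sketch of that computation. The function name is hypothetical, and it ignores class weights and label smoothing.

```python
import math


def cross_entropy(logits, targets):
    """Mean of -log softmax(logits)[target] over the batch."""
    total = 0.0
    for row, target in zip(logits, targets):
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        total += log_sum_exp - row[target]
    return total / len(targets)


# Equal logits over the 5 classes give a loss of log(5), about 1.609
print(round(cross_entropy([[0.0] * 5], [2]), 3))  # 1.609
```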
4. Train, Validate, and Test
```python
import torch

# Assumes model, optimizer, criterion, device, epochs, viz (a visdom.Visdom
# instance) and train_loader / val_loader / test_loader are defined earlier.


def evaluate(model, loader):
    model.eval()  # use the running BatchNorm statistics during evaluation
    correct = 0
    total = len(loader.dataset)
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            output = model(x)
            pred = output.argmax(dim=1)
        correct += torch.eq(pred, y).sum().item()
    return correct / total


best_acc, best_epoch = 0, 0
global_step = 0
viz.line([0], [-1], win='loss', opts=dict(title='loss'))
viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
for epoch in range(epochs):
    model.train()  # back to training mode (evaluate() switches to eval mode)
    for step, (x, y) in enumerate(train_loader):
        # x: [b,3,224,224] y: [b]
        x, y = x.to(device), y.to(device)
        output = model(x)
        loss = criterion(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        viz.line([loss.item()], [global_step], win='loss', update='append')
        global_step += 1

    # Validation (every epoch)
    if epoch % 1 == 0:
        val_acc = evaluate(model, val_loader)
        if val_acc > best_acc:
            # Keep the checkpoint with the highest validation accuracy
            best_acc = val_acc
            best_epoch = epoch
            torch.save(model.state_dict(), 'best.mdl')
        viz.line([val_acc], [global_step], win='val_acc', update='append')

print('best acc:', best_acc, 'best epoch:', best_epoch + 1)

# Load the best model and evaluate it on the test set
model.load_state_dict(torch.load('best.mdl'))
print('loaded from ckpt!')
test_acc = evaluate(model, test_loader)
print('test acc:', test_acc)
```
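The checkpointing rule in the training loop keeps the first epoch with the strictly highest validation accuracy and reports that epoch 1-based (`best_epoch + 1`). Below is a small pure-Python sketch of just that rule. The function name is hypothetical, and the accuracies in the example are made up for illustration; they are not results from the article.

```python
def best_checkpoint(val_accs):
    """Return (best_acc, best_epoch) with a 1-based epoch, mirroring the loop:
    a later epoch only replaces the saved model if it is strictly better."""
    best_acc, best_epoch = 0, 0
    for epoch, acc in enumerate(val_accs):
        if acc > best_acc:
            best_acc, best_epoch = acc, epoch
    return best_acc, best_epoch + 1


# Illustrative accuracies only: epoch 3 is best, and the tie in epoch 4 is ignored
print(best_checkpoint([0.80, 0.90, 0.93, 0.93, 0.91]))  # (0.93, 3)
```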
5. Test results
The training loss curve and the validation accuracy curve recorded in Visdom are shown in the figure below. The console output is:

```
best acc: 0.9358974358974359 best epoch: 3
loaded from ckpt!
test acc: 0.9401709401709402
```
The validation accuracy is highest at epoch 3 (about 93.59%), so the model from that epoch is taken as the best model. Evaluated on the test set, it reaches an accuracy of 94.02%.
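These numbers can be checked for consistency. With the 60/20/20 split of 1168 images, the validation and test sets each hold 234 images, and `evaluate` returns `correct / total`. Assuming the CSV covers all 1168 images, the reported accuracies correspond to 219 and 220 correct predictions out of 234:

```python
n = 1168
val_size = int(0.8 * n) - int(0.6 * n)   # 934 - 700 = 234
test_size = n - int(0.8 * n)             # 1168 - 934 = 234

reported = {'val': (0.9358974358974359, val_size),
            'test': (0.9401709401709402, test_size)}
for split, (acc, size) in reported.items():
    correct = round(acc * size)
    # True if correct / size reproduces the reported accuracy
    print(split, correct, size, abs(correct / size - acc) < 1e-12)
# val 219 234 True
# test 220 234 True
```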