1. Proceso de implementación
1. Descripción del conjunto de datos
Los conjuntos de datos se dividen en 5 categorías, de la siguiente manera:
- Pikachu: 234
- Mewdos: 239
- Jenny Tortuga: 223
- Pequeño dragón de fuego: 238
- Semillas de rana: 234
Enlace de recuperación automática: https://pan.baidu.com/s/1bsppVXDRsweVKAxSoLy4sw
Código de extracción: 9fqo Las
extensiones de archivo de imagen tienen 4 tipos de jpg, jepg, png y gif, y los tamaños de las imágenes no son los mismos, por lo que (es necesario verificar y probar) las imágenes se redimensionan y otras operaciones. En este documento, el tamaño de la imagen se redimensiona a 224×224.
2. Preprocesamiento de datos
Este documento utiliza el marco de conjunto de datos para preprocesar el conjunto de datos y convierte el conjunto de datos de imagen en una relación de mapeo como {imágenes, etiquetas}.
def __init__(self, root, resize, mode):
super(Pokemon, self).__init__()
self.root = root
self.resize = resize
self.name2label = {
} # "sq...": 0
for name in sorted(os.listdir(os.path.join(root))):
if not os.path.isdir(os.path.join(root,name)):
continue
self.name2label[name] = len(self.name2label.keys())
# print(self.name2label)
# image,label
self.images, self.labels = self.load_csv('images.csv')
# 数据集裁剪:训练集、验证集、测试集
if mode == 'train': # 60%
self.images = self.images[0:int(0.6*len(self.images))]
self.labels = self.labels[0:int(0.6*len(self.labels))]
elif mode == 'val': # 20% = 60% -> 80%
self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
else: # 20% = 80% -> 100%
self.images = self.images[int(0.8*len(self.images)):]
self.labels = self.labels[int(0.8*len(self.labels)):]
Entre ellos, root representa el directorio raíz del archivo donde se almacena el conjunto de datos; resize representa el tamaño uniforme de la salida del conjunto de datos; mode representa el modo (tren, val y test) cuando se lee el conjunto de datos; name2label es para construir una estructura de diccionario de nombres y etiquetas de categorías de imágenes, es conveniente obtener la etiqueta de la categoría de imágenes, el método load_csv es crear una relación de mapeo de {imágenes, etiquetas}, donde imágenes representa la ruta del archivo donde se encuentra la imagen, y el código es como sigue:
def load_csv(self, filename):
if not os.path.exists(os.path.join(self.root, filename)):
# 文件不存在,则需要创建该文件
images = []
for name in self.name2label.keys():
# pokemon\\mewtwo\\00001.png
images += glob.glob(os.path.join(self.root,name,'*.png'))
images += glob.glob(os.path.join(self.root, name, '*.jpg'))
images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
images += glob.glob(os.path.join(self.root, name, '*.gif'))
# 1168, 'pokemon\\bulbasaur\\00000000.png'
print(len(images),images)
# 保存成image,label的csv文件
random.shuffle(images)
with open(os.path.join(self.root, filename),mode='w',newline='') as f:
writer = csv.writer(f)
for img in images: # 'pokemon\\bulbasaur\\00000000.png'
name = img.split(os.sep)[-2]
label = self.name2label[name]
# 'pokemon\\bulbasaur\\00000000.png', 0
writer.writerow([img, label])
# print('writen into csv file:',filename)
# 加载已保存的csv文件
images, labels = [],[]
with open(os.path.join(self.root,filename)) as f:
reader = csv.reader(f)
for row in reader:
img, label = row
label = int(label)
images.append(img)
labels.append(label)
assert len(images) == len(labels)
return images, labels
El código para obtener el tamaño del conjunto de datos y la posición del elemento de índice es:
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
# idx:[0, len(self.images)]
# self.images, self.labels
# img:'G:/datasets/pokemon\\charmander\\00000182.png'
# label: 0,1,2,3,4
img, label = self.images[idx], self.labels[idx]
transform = transforms.Compose([
lambda x: Image.open(x).convert('RGB'), # string path => image data
transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),
transforms.RandomRotation(15), # 随机旋转
transforms.CenterCrop(self.resize), # 中心裁剪
transforms.ToTensor(),
# transforms.Normalize(mean=[0.485,0.456,0.406],
# std=[0.229,0.224,0.225])
transforms.Normalize(mean=[0.6096, 0.7286, 0.5103],
std=[1.5543, 1.4887, 1.5958])
])
img = transform(img)
label = torch.tensor(label)
return img, label
Entre ellos, consulte el cálculo de media y std en transforms.Normalize , o use directamente los valores empíricos mean=[0.485, 0.456, 0.406] y std=[0.229, 0.224, 0.225].
La imagen de batch_size=32 que muestra la herramienta de visualización Visdom se muestra en la siguiente figura:
2. Modelo de diseño
Este documento adopta la idea del aprendizaje de la migración, utiliza directamente el clasificador resnet18, conserva sus primeras 17 capas de estructura de red y modifica la última capa en consecuencia. El código es el siguiente:
trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1], # [b,512,1,1]
Flatten(), # [b,512,1,1] => [b,512]
nn.Linear(512, 5)
).to(device)
Entre ellos, Flatten() es el método de aplanamiento de datos, el código es el siguiente:
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, x):
shape = torch.prod(torch.tensor(x.shape[1:])).item()
return x.view(-1, shape)
3. Construya la función de pérdida y el optimizador
La función de pérdida usa entropía cruzada, el optimizador usa Adam y la tasa de aprendizaje se establece en 0. 001. El código es el siguiente:
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
4. Capacitar, validar y probar
best_acc, best_epoch = 0, 0
global_step = 0
viz.line([0], [-1], win='loss', opts=dict(title='loss'))
viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
for epoch in range(epochs):
for step, (x,y) in enumerate(train_loader):
# x: [b,3,224,224] y: [b]
x, y = x.to(device), y.to(device)
output = model(x)
loss = criterion(output, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
viz.line([loss.item()], [global_step], win='loss', update='append')
global_step += 1
# 验证集
if epoch % 1 == 0:
val_acc = evaluate(model, val_loader)
if val_acc > best_acc:
best_acc = val_acc
best_epoch = epoch
torch.save(model.state_dict(), 'best.mdl')
viz.line([val_acc], [global_step], win='val_acc', update='append')
print('best acc:', best_acc, 'best epoch:', best_epoch+1)
# 加载最好的模型
model.load_state_dict(torch.load('best.mdl'))
print('loaded from ckpt!')
test_acc = evaluate(model, test_loader)
print('test acc:', test_acc)
def evaluate(model, loader):
correct = 0
total = len(loader.dataset)
for (x, y) in loader:
x, y = x.to(device), y.to(device)
with torch.no_grad():
output = model(x)
pred = output.argmax(dim=1)
correct += torch.eq(pred, y).sum().item()
return correct/total
5. Resultados de la prueba
La curva de cambio del valor de pérdida del conjunto de entrenamiento y la curva de cambio de la precisión del conjunto de prueba se muestran en la siguiente figura: La
salida de la consola es:
best acc: 0.9358974358974359 best epoch: 3
loaded from ckpt!
test acc: 0.9401709401709402
Esto muestra que: cuando epoch=3, la precisión del conjunto de validación alcanza el máximo, y el modelo en este momento puede considerarse como el mejor modelo, y se utiliza para la prueba del conjunto de prueba, alcanzando una precisión de 94,02 %
2. Referencias
[1] https://www.bilibili.com/video/BV1f34y1k7fi?p=106
[2] https://blog.csdn.net/Weary_PJ/article/details/122765199