PyTorch implementa la clasificación de conjuntos de datos de Pokémon basada en el aprendizaje de transferencia ResNet18

1. Proceso de implementación

1. Descripción del conjunto de datos

Los conjuntos de datos se dividen en 5 categorías, de la siguiente manera:

  • Pikachu: 234
  • Mewdos: 239
  • Jenny Tortuga: 223
  • Pequeño dragón de fuego: 238
  • Semillas de rana: 234

Enlace de recuperación automática: https://pan.baidu.com/s/1bsppVXDRsweVKAxSoLy4sw
Código de extracción: 9fqo Las
extensiones de archivo de imagen tienen 4 tipos de jpg, jepg, png y gif, y los tamaños de las imágenes no son los mismos, por lo que (es necesario verificar y probar) las imágenes se redimensionan y otras operaciones. En este documento, el tamaño de la imagen se redimensiona a 224×224.

2. Preprocesamiento de datos

Este documento utiliza el marco de conjunto de datos para preprocesar el conjunto de datos y convierte el conjunto de datos de imagen en una relación de mapeo como {imágenes, etiquetas}.

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()

        self.root = root
        self.resize = resize

        self.name2label = {
    
    }    # "sq...": 0
        for name in sorted(os.listdir(os.path.join(root))):
            if not os.path.isdir(os.path.join(root,name)):
                continue
            self.name2label[name] = len(self.name2label.keys())
        # print(self.name2label)

        # image,label
        self.images, self.labels = self.load_csv('images.csv')

        # 数据集裁剪:训练集、验证集、测试集
        if mode == 'train': # 60%
            self.images = self.images[0:int(0.6*len(self.images))]
            self.labels = self.labels[0:int(0.6*len(self.labels))]
        elif mode == 'val': # 20% = 60% -> 80%
            self.images = self.images[int(0.6*len(self.images)):int(0.8*len(self.images))]
            self.labels = self.labels[int(0.6*len(self.labels)):int(0.8*len(self.labels))]
        else:               # 20% = 80% -> 100%
            self.images = self.images[int(0.8*len(self.images)):]
            self.labels = self.labels[int(0.8*len(self.labels)):]

Entre ellos, root representa el directorio raíz del archivo donde se almacena el conjunto de datos; resize representa el tamaño uniforme de la salida del conjunto de datos; mode representa el modo (tren, val y test) cuando se lee el conjunto de datos; name2label es para construir una estructura de diccionario de nombres y etiquetas de categorías de imágenes, es conveniente obtener la etiqueta de la categoría de imágenes, el método load_csv es crear una relación de mapeo de {imágenes, etiquetas}, donde imágenes representa la ruta del archivo donde se encuentra la imagen, y el código es como sigue:

    def load_csv(self, filename):
        if not os.path.exists(os.path.join(self.root, filename)):
            # 文件不存在,则需要创建该文件
            images = []
            for name in self.name2label.keys():
                # pokemon\\mewtwo\\00001.png
                images += glob.glob(os.path.join(self.root,name,'*.png'))
                images += glob.glob(os.path.join(self.root, name, '*.jpg'))
                images += glob.glob(os.path.join(self.root, name, '*.jpeg'))
                images += glob.glob(os.path.join(self.root, name, '*.gif'))
            # 1168, 'pokemon\\bulbasaur\\00000000.png'
            print(len(images),images)
            # 保存成image,label的csv文件
            random.shuffle(images)
            with open(os.path.join(self.root, filename),mode='w',newline='') as f:
                writer = csv.writer(f)
                for img in images:  # 'pokemon\\bulbasaur\\00000000.png'
                    name = img.split(os.sep)[-2]
                    label = self.name2label[name]
                    # 'pokemon\\bulbasaur\\00000000.png', 0
                    writer.writerow([img, label])
                # print('writen into csv file:',filename)
        # 加载已保存的csv文件
        images, labels = [],[]
        with open(os.path.join(self.root,filename)) as f:
            reader = csv.reader(f)
            for row in reader:
                img, label = row
                label = int(label)
                images.append(img)
                labels.append(label)
        assert len(images) == len(labels)
        return images, labels

El código para obtener el tamaño del conjunto de datos y la posición del elemento de índice es:

    def __len__(self):
        return len(self.images)
    def __getitem__(self, idx):
        # idx:[0, len(self.images)]
        # self.images, self.labels
        # img:'G:/datasets/pokemon\\charmander\\00000182.png'
        # label: 0,1,2,3,4
        img, label = self.images[idx], self.labels[idx]
        transform = transforms.Compose([
            lambda x: Image.open(x).convert('RGB'),  # string path => image data
            transforms.Resize((int(self.resize*1.25),int(self.resize*1.25))),
            transforms.RandomRotation(15),      # 随机旋转
            transforms.CenterCrop(self.resize), # 中心裁剪
            transforms.ToTensor(),
            # transforms.Normalize(mean=[0.485,0.456,0.406],
            #                      std=[0.229,0.224,0.225])
            transforms.Normalize(mean=[0.6096, 0.7286, 0.5103],
                                 std=[1.5543, 1.4887, 1.5958])
        ])

        img = transform(img)
        label = torch.tensor(label)
        return img, label

Entre ellos, consulte el cálculo de media y std en transforms.Normalize , o use directamente los valores empíricos mean=[0.485, 0.456, 0.406] y std=[0.229, 0.224, 0.225].
La imagen de batch_size=32 que muestra la herramienta de visualización Visdom se muestra en la siguiente figura:
inserte la descripción de la imagen aquí

2. Modelo de diseño

Este documento adopta la idea del aprendizaje de la migración, utiliza directamente el clasificador resnet18, conserva sus primeras 17 capas de estructura de red y modifica la última capa en consecuencia. El código es el siguiente:

trained_model = resnet18(pretrained=True)
model = nn.Sequential(*list(trained_model.children())[:-1],     # [b,512,1,1]
                      Flatten(),   # [b,512,1,1] => [b,512]
                      nn.Linear(512, 5)
                      ).to(device)

Entre ellos, Flatten() es el método de aplanamiento de datos, el código es el siguiente:

class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)

3. Construya la función de pérdida y el optimizador

La función de pérdida usa entropía cruzada, el optimizador usa Adam y la tasa de aprendizaje se establece en 0. 001. El código es el siguiente:

optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

4. Capacitar, validar y probar

	best_acc, best_epoch = 0, 0
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
    for epoch in range(epochs):
        for step, (x,y) in enumerate(train_loader):
            # x: [b,3,224,224]  y: [b]
            x, y = x.to(device), y.to(device)
            output = model(x)
            loss = criterion(output, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1
        # 验证集
        if epoch % 1 == 0:
            val_acc = evaluate(model, val_loader)
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(model.state_dict(), 'best.mdl')
                viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch+1)
    # 加载最好的模型
    model.load_state_dict(torch.load('best.mdl'))
    print('loaded from ckpt!')
    test_acc = evaluate(model, test_loader)
    print('test acc:', test_acc)
def evaluate(model, loader):
    correct = 0
    total = len(loader.dataset)
    for (x, y) in loader:
        x, y = x.to(device), y.to(device)
        with torch.no_grad():
            output = model(x)
            pred = output.argmax(dim=1)
            correct += torch.eq(pred, y).sum().item()
    return correct/total

5. Resultados de la prueba

La curva de cambio del valor de pérdida del conjunto de entrenamiento y la curva de cambio de la precisión del conjunto de prueba se muestran en la siguiente figura: La
inserte la descripción de la imagen aquísalida de la consola es:

best acc: 0.9358974358974359 best epoch: 3
loaded from ckpt!

test acc: 0.9401709401709402

Esto muestra que: cuando epoch=3, la precisión del conjunto de validación alcanza el máximo, y el modelo en este momento puede considerarse como el mejor modelo, y se utiliza para la prueba del conjunto de prueba, alcanzando una precisión de 94,02 %

2. Referencias

[1] https://www.bilibili.com/video/BV1f34y1k7fi?p=106
[2] https://blog.csdn.net/Weary_PJ/article/details/122765199

Supongo que te gusta

Origin blog.csdn.net/weixin_43821559/article/details/123561478
Recomendado
Clasificación