El aprendizaje profundo completa la clasificación de imágenes basada en la red ResNet18

una. prefacio

Esta tarea consiste en utilizar la red ResNet18 para practicar una tarea de clasificación de imágenes más general.

La serie de redes ResNet, algoritmos bien conocidos en el campo de la clasificación de imágenes, son duraderos y atemporales, y todavía tienen una amplia gama de escenarios de aplicación e importancia para la investigación hasta el día de hoy. La industria ha realizado varias mejoras y, a menudo, se utilizan para tareas de reconocimiento de imágenes.

Hoy, presento principalmente el caso de la estructura de red ResNet-18, y a su vez se pueden deducir otras redes profundas.

ResNet-18, el número representa la profundidad de la red, es decir, ¿la red ResNet18 tiene 18 capas? De hecho, 18 aquí especifica 18 capas con pesos, incluidas las capas convolucionales y las capas totalmente conectadas, excluyendo las capas de agrupación y las capas BN.

La clasificación de imágenes es una tarea básica en la visión por computadora, que divide diferentes imágenes en diferentes categorías según la semántica de las imágenes. Muchas tareas también se pueden convertir en tareas de clasificación de imágenes. Por ejemplo, la detección de rostros es para determinar si hay un rostro en un área, lo que puede considerarse como una tarea de clasificación de imágenes binarias.

  • Conjunto de datos: el clásico conjunto de datos CIFAR-10 utilizado en el campo de la visión por computadora
  • Capa de red: la red es el modelo ResNet18
  • Optimizer: El optimizador es el optimizador de Adam
  • Función de pérdida: la función de pérdida es pérdida de entropía cruzada
  • Índice de evaluación: El índice de evaluación es la tasa de precisión

Introducción a la red ResNet:

inserte la descripción de la imagen aquí

dos. Preprocesamiento de datos

2.1 Introducción del conjunto de datos

El conjunto de datos CIFAR-10 contiene 10 categorías diferentes con un total de 60 000 imágenes, de las cuales cada categoría tiene 6000 imágenes y el tamaño de la imagen es de 32 × 3232 × 32 píxeles.

2.2 Lectura de datos

En este experimento, el conjunto de entrenamiento original se divide en dos partes, train_set y dev_set, que incluyen 40 000 y 10 000 muestras, respectivamente. Tome data_batch_1 a data_batch_4 como conjunto de entrenamiento, data_batch_5 como conjunto de validación y test_batch como conjunto de prueba. El conjunto de datos final consta de:

  • Conjunto de entrenamiento: 40 000 muestras.
  • Conjunto de validación: 10 000 muestras.
  • Conjunto de prueba: 10 000 muestras.

El código para leer un lote de datos es el siguiente:

import os
import pickle
import numpy as np

def load_cifar10_batch(folder_path, batch_id=1, mode='train'):
    if mode == 'test':
        file_path = os.path.join(folder_path, 'test_batch')
    else:
        file_path = os.path.join(folder_path, 'data_batch_'+str(batch_id))

    #加载数据集文件
    with open(file_path, 'rb') as batch_file:
        batch = pickle.load(batch_file, encoding = 'latin1')

    imgs = batch['data'].reshape((len(batch['data']),3,32,32)) / 255.
    labels = batch['labels']

    return np.array(imgs, dtype='float32'), np.array(labels)

imgs_batch, labels_batch = load_cifar10_batch(folder_path='datasets/cifar-10-batches-py', 
                                                batch_id=1, mode='train')

Ver las dimensiones de los datos:

#打印一下每个batch中X和y的维度
print ("batch of imgs shape: ",imgs_batch.shape, "batch of labels shape: ", labels_batch.shape)

lote de forma de imgs: (10000, 3, 32, 32) lote de forma de etiquetas: (10000,)

Observa visualmente una de las imágenes de muestra y las etiquetas correspondientes, el código es el siguiente:

%matplotlib inline
import matplotlib.pyplot as plt

image, label = imgs_batch[1], labels_batch[1]
print("The label in the picture is {}".format(label))
plt.figure(figsize=(2, 2))
plt.imshow(image.transpose(1,2,0))
plt.savefig('cnn-car.pdf')

2.3 Construyendo la clase Dataset

Construya una clase CIFAR10Dataset, que heredará de la paddle.io.DataSetclase y puede procesar los datos uno por uno. El código se implementa de la siguiente manera:

import paddle
import paddle.io as io
from paddle.vision.transforms import Normalize

class CIFAR10Dataset(io.Dataset):
    def __init__(self, folder_path='/home/aistudio/cifar-10-batches-py', mode='train'):
        if mode == 'train':
            #加载batch1-batch4作为训练集
            self.imgs, self.labels = load_cifar10_batch(folder_path=folder_path, batch_id=1, mode='train')
            for i in range(2, 5):
                imgs_batch, labels_batch = load_cifar10_batch(folder_path=folder_path, batch_id=i, mode='train')
                self.imgs, self.labels = np.concatenate([self.imgs, imgs_batch]), np.concatenate([self.labels, labels_batch])
        elif mode == 'dev':
            #加载batch5作为验证集
            self.imgs, self.labels = load_cifar10_batch(folder_path=folder_path, batch_id=5, mode='dev')
        elif mode == 'test':
            #加载测试集
            self.imgs, self.labels = load_cifar10_batch(folder_path=folder_path, mode='test')
        self.transform = Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010], data_format='CHW')

    def __getitem__(self, idx):
        img, label = self.imgs[idx], self.labels[idx]
        img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.imgs)

paddle.seed(100)
train_dataset = CIFAR10Dataset(folder_path='datasets/cifar-10-batches-py', mode='train')
dev_dataset = CIFAR10Dataset(folder_path='datasets/cifar-10-batches-py', mode='dev')
test_dataset = CIFAR10Dataset(folder_path='datasets/cifar-10-batches-py', mode='test')

3. Construcción de modelos

Experimentos de clasificación de imágenes utilizando Resnet18 en la API de alto nivel de Flying Paddle.

from paddle.vision.models import resnet18

resnet18_model = resnet18()

La API de alto nivel de la paleta es una encapsulación y actualización adicional de la API de la paleta, proporcionando una API más concisa y fácil de usar, que mejora aún más la facilidad de aprendizaje y uso de la paleta. Entre ellos, la API de alto nivel de Flying Paddle encapsula los siguientes módulos:

  1. Clase de modelo, que admite el entrenamiento de modelos con solo unas pocas líneas de código;
  2. Módulo de preprocesamiento de imágenes, que incluye docenas de funciones de procesamiento de datos, que básicamente cubren métodos comunes de procesamiento y mejora de datos;
  3. Modelos comunes en el campo de la visión artificial y el procesamiento del lenguaje natural, incluidos, entre otros, mobilenet, resnet, yolov3, cyclegan, bert, transformer, seq2seq, etc. Al mismo tiempo, se lanzan los modelos preentrenados de los modelos correspondientes. , y estos modelos se pueden utilizar directamente o aquí Sobre la base de la finalización del desarrollo secundario.

4. Entrenamiento modelo

Reutilice la clase RunnerV3, cree una instancia de la clase RunnerV3 y pase la configuración de entrenamiento. El entrenamiento del modelo se realiza utilizando el conjunto de entrenamiento y el conjunto de validación para un total de 30 épocas. En los experimentos, guarde el modelo con la mayor precisión como el mejor modelo. El código se implementa de la siguiente manera:

import paddle.nn.functional as F
import paddle.optimizer as opt
from nndl import RunnerV3, Accuracy

#指定运行设备
use_gpu = True if paddle.get_device().startswith("gpu") else False
if use_gpu:
    paddle.set_device('gpu:0')
#学习率大小
lr = 0.001  
#批次大小
batch_size = 64 
#加载数据
train_loader = io.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dev_loader = io.DataLoader(dev_dataset, batch_size=batch_size)
test_loader = io.DataLoader(test_dataset, batch_size=batch_size) 
#定义网络
model = resnet18_model
#定义优化器,这里使用Adam优化器以及l2正则化策略,相关内容在7.3.3.2和7.6.2中会进行详细介绍
optimizer = opt.Adam(learning_rate=lr, parameters=model.parameters(), weight_decay=0.005)
#定义损失函数
loss_fn = F.cross_entropy
#定义评价指标
metric = Accuracy(is_logist=True)
#实例化RunnerV3
runner = RunnerV3(model, optimizer, loss_fn, metric)
#启动训练
log_steps = 3000
eval_steps = 3000
runner.train(train_loader, dev_loader, num_epochs=30, log_steps=log_steps, 
                eval_steps=eval_steps, save_path="best_model.pdparams")

Observe visualmente los cambios de precisión y pérdida del conjunto de entrenamiento y el conjunto de validación.

from nndl import plot

plot(runner, fig_name='cnn-loss4.pdf')

En este experimento, se utiliza el optimizador de Adam presentado en el Capítulo 7 para la optimización de la red. Si se utiliza el optimizador SGD, se producirá el fenómeno de sobreajuste y no se podrá obtener un buen efecto de convergencia en el conjunto de validación. Puede intentar ajustar la configuración de entrenamiento usando otras estrategias de optimización en el Capítulo 7 para lograr una mayor precisión del modelo.

V. Evaluación del Modelo

Utilice los datos de prueba para evaluar el mejor modelo guardado durante el proceso de entrenamiento y observe la precisión y la pérdida del modelo en el conjunto de prueba. El código se implementa de la siguiente manera:

# 加载最优模型
runner.load_model('best_model.pdparams')
# 模型评价
score, loss = runner.evaluate(test_loader)
print("[Test] accuracy/loss: {:.4f}/{:.4f}".format(score, loss))

[Prueba] precisión/pérdida: 0,7234/0,8324

6. Modelo de predicción

De manera similar, también puede usar el modelo guardado para hacer predicciones de modelo sobre los datos en el conjunto de prueba y observar el efecto del modelo. El código específico se implementa de la siguiente manera:

#获取测试集中的一个batch的数据
X, label = next(test_loader())
logits = runner.predict(X)
#多分类,使用softmax计算预测概率
pred = F.softmax(logits)
#获取概率最大的类别
pred_class = paddle.argmax(pred[2]).numpy()
label = label[2][0].numpy()
#输出真实类别与预测类别
print("The true category is {} and the predicted category is {}".format(label[0], pred_class[0]))
#可视化图片
plt.figure(figsize=(2, 2))
imgs, labels = load_cifar10_batch(folder_path='/home/aistudio/datasets/cifar-10-batches-py', mode='test')
plt.imshow(imgs[2].transpose(1,2,0))
plt.savefig('cnn-test-vis.pdf')

La verdadera categoría es 8 y la categoría predicha es 8

El real es 8, la predicción es 8. Embarcacion

Supongo que te gusta

Origin blog.csdn.net/m0_59596937/article/details/127354485
Recomendado
Clasificación