Clasificación de imágenes de flores de Resnet y Pytorch

1. Introducción

1.1 Introducción al conjunto de datos

flower_data
    ├── train
    │   └── 1-102(102个文件夹)
    │   	└── XXX.jpg(每个文件夹含若干张图像)
    ├── valid
    │   └── 1-102(102个文件夹)
    └── ───	└── XXX.jpg(每个文件夹含若干张图像)  
     
cat_to_name.json:每一类花朵的"名称-编号"对应关系

1.2 Introducción a la tarea

Realice la tarea de clasificación de 102 tipos de flores, es decir, después de pasar el trainconjunto de datos de entrenamiento, validseleccione una determinada imagen de flor del conjunto de datos y pueda determinar con precisión a qué tipo de flor pertenece.

1.3Introducción a Resnet

Hay dos aspectos destacados en la red ResNet:

  1. Proponer una estructura residual (estructura residual) y construir una estructura de red ultra profunda (superar 1000 capas)
  2. Utilice la normalización por lotes para acelerar el entrenamiento (descarte el abandono)

Antes de que se propusiera la red ResNet, la red neuronal convolucional tradicional se obtenía apilando una serie de capas convolucionales y capas de muestreo reducido. Pero al apilar a una determinada profundidad de red, habrá dos problemas:

  1. Degradados que desaparecen o explotan
  2. problema de degradación

2. Preprocesamiento de datos

2.1 Importar el archivo de encabezado

import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torch.optim as optim
import torchvision
from torchvision import transforms,models,datasets
import imageio
import time
import warnings
import random
import sys
import copy
import json
from PIL import Image

2.2 Lectura de datos

#数据读取与预处理操作
data_dir = './flower_data/'
# 训练集
train_dir = data_dir + '/train'
#验证集
valid_ir = data_dir + '/valid'

2.3 Crear fuente de datos

#制作数据源
data_transfroms = {
    'train':transforms.Compose([transforms.RandomRotation(45), #随机旋转(-45~45)
    transforms.CenterCrop(224), #从中心开始裁剪
    transforms.RandomHorizontalFlip(p = 0.5), #随机水平翻转
    transforms.RandomVerticalFlip(p = 0.5), #随机垂直翻转
    transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue = 0.1),
    transforms.RandomGrayscale(p = 0.025), #概率转换成灰度率,3通道就是R=G=B
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
    'valid':transforms.Compose([transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
}

2.4 producción de datos por lotes

#batch数据制作
batch_size = 8
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir,x),data_transfroms[x]) for x in ['train','valid']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],batch_size = batch_size,shuffle = True) for x in ['train','valid']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train','valid']}
class_names = image_datasets['train'].classes

2.5 Leer etiqueta de datos

#读取标签对应的实际名字
with open('cat_to_name.json','r') as f:
    cat_to_name = json.load(f)

Vea el archivo cat_to_name.json:

{'21': 'fire lily',
 '3': 'canterbury bells',
 '45': 'bolero deep blue',
 '1': 'pink primrose',
 '34': 'mexican aster',
 '27': 'prince of wales feathers',
 '7': 'moon orchid',
 '16': 'globe-flower',
 '25': 'grape hyacinth',
 '26': 'corn poppy',
 '79': 'toad lily',
 '39': 'siam tulip',
 '24': 'red ginger',
 '67': 'spring crocus',
 '35': 'alpine sea holly',
 '32': 'garden phlox',
 '10': 'globe thistle',
 '6': 'tiger lily',
 '93': 'ball moss',
 '33': 'love in the mist',
 '9': 'monkshood',
 '102': 'blackberry lily',
 '14': 'spear thistle',
 '19': 'balloon flower',
 '100': 'blanket flower',
 '13': 'king protea',
 '49': 'oxeye daisy',
 '15': 'yellow iris',
 '61': 'cautleya spicata',
 '31': 'carnation',
 '64': 'silverbush',
 '68': 'bearded iris',
 '63': 'black-eyed susan',
 '69': 'windflower',
 '62': 'japanese anemone',
 '20': 'giant white arum lily',
 '38': 'great masterwort',
 '4': 'sweet pea',
 '86': 'tree mallow',
 '101': 'trumpet creeper',
 '42': 'daffodil',
 '22': 'pincushion flower',
 '2': 'hard-leaved pocket orchid',
 '54': 'sunflower',
 '66': 'osteospermum',
 '70': 'tree poppy',
 '85': 'desert-rose',
 '99': 'bromelia',
 '87': 'magnolia',
 '5': 'english marigold',
 '92': 'bee balm',
 '28': 'stemless gentian',
 '97': 'mallow',
 '57': 'gaura',
 '40': 'lenten rose',
 '47': 'marigold',
 '59': 'orange dahlia',
 '48': 'buttercup',
 '55': 'pelargonium',
 '36': 'ruby-lipped cattleya',
 '91': 'hippeastrum',
 '29': 'artichoke',
 '71': 'gazania',
 '90': 'canna lily',
 '18': 'peruvian lily',
 '98': 'mexican petunia',
 '8': 'bird of paradise',
 '30': 'sweet william',
 '17': 'purple coneflower',
 '52': 'wild pansy',
 '84': 'columbine',
 '12': "colt's foot",
 '11': 'snapdragon',
 '96': 'camellia',
 '23': 'fritillary',
 '50': 'common dandelion',
 '44': 'poinsettia',
 '53': 'primula',
 '72': 'azalea',
 '65': 'californian poppy',
 '80': 'anthurium',
 '76': 'morning glory',
 '37': 'cape flower',
 '56': 'bishop of llandaff',
 '60': 'pink-yellow dahlia',
 '82': 'clematis',
 '58': 'geranium',
 '75': 'thorn apple',
 '41': 'barbeton daisy',
 '95': 'bougainvillea',
 '43': 'sword lily',
 '83': 'hibiscus',
 '78': 'lotus lotus',
 '88': 'cyclamen',
 '94': 'foxglove',
 '81': 'frangipani',
 '74': 'rose',
 '89': 'watercress',
 '73': 'water lily',
 '46': 'wallflower',
 '77': 'passion flower',
 '51': 'petunia'}

3. Visualización de datos

3.1 Función de procesamiento de imágenes

#展示数据
def im_convert(tensor):
    image = tensor.to("cpu").clone().detach()
    image = image.numpy().squeeze()
    image = image.transpose(1,2,0)
    image = image * np.array((0.229,0.224,0.225)) + np.array((0.485,0.456,0.406))
    image = image.clip(0.1)

    return image

3.2 Mostrar imágenes

fig=plt.figure(figsize=(20, 12))
columns = 4
rows = 2

dataiter = iter(dataloaders['valid'])
inputs, classes = dataiter.next()

for idx in range (columns*rows):
    ax = fig.add_subplot(rows, columns, idx+1, xticks=[], yticks=[])
    ax.set_title(cat_to_name[str(int(class_names[classes[idx]]))])
    plt.imshow(im_convert(inputs[idx]))
plt.show()

4. Realizar transferencia de aprendizaje

Puntos clave para el aprendizaje por transferencia:

  • Investigar qué conocimientos se pueden utilizar para la transferencia de aprendizaje en diferentes campos o tareas, es decir, qué conocimientos comunes se pueden transferir entre diferentes campos.
  • Una vez encontrado el objeto de transferencia, qué algoritmo específico de aprendizaje por transferencia se utiliza para problemas específicos, es decir, cómo diseñar un algoritmo adecuado para extraer y transferir conocimiento común.
  • Investigación sobre qué circunstancias son adecuadas para la migración y si la técnica de migración es adecuada para aplicaciones específicas, lo que implica el tema de la transferencia negativa.

4.1 Capacitación de capa totalmente conectada

Cargue el modelo proporcionado en los modelos y utilice directamente el peso entrenado como parámetro de inicialización. 

Enlace de descarga: https://download.pytorch.org/models/resnet152-394f9c45.pth

 Seleccione la red resnet

model_name = 'resnet'  #可选的有: ['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception']

#是否用官方训练好的特征来做
feature_extract = True 

Configurar el entrenamiento con GPU

#是否用GPU来训练
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('cuda is not available. Training on CPU')
else:
    print('cuda is available. Training on GPU')

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Enmascare los pesos del modelo previamente entrenado y entrene solo los pesos de la capa completamente conectada: 

def set_parameter_requires_grad(model,feature_extracting):
    if feature_extracting:
        for param in model.parameter():
            param.requires_grad = False

Seleccione la red resnet152

model_ft = models.resnet152()

Configure el optimizador:

#优化器设置
optimizer_ft = optim.Adam(params_to_update,lr = 1e-2)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1) #学习率每7个epoch衰减成原来的1/10
criterion = nn.NLLLoss()

Definir el módulo de formación:

# 训练模块
def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False,filename = filename):
    since = time.time()
    best_acc = 0

    model.to(device)
    val_acc_history = []
    train_acc_history = []
    train_losses = []
    valid_losses = []
    LRs = [optimizer.param_groups[0]['lr']]

    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        print('Epoch {} / {}'.format(epoch,num_epochs - 1))
        print('-' * 10)

        #训练与验证
        for phase in ['train','valid']:
            if phase == 'train':
                model.train()  #训练
            else:
                model.eval()  #验证

            running_loss = 0.0
            running_corrects = 0

            #把数据取个遍
            for inputs,labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                #清零
                optimizer.zero_grad()

                #只有训练的时候计算与更新梯度
                with torch.set_grad_enabled(phase == 'train'):
                    if is_inception and phase == 'train':
                        outputs,aux_outputs = model(inputs)
                        loss1 = criterion(outputs,labels)
                        loss2 = criterion(aux_outputs,labels)
                        loss = loss1 + 0.4 * loss2
                    else: #resnet执行的是这里
                        outputs = model(inputs)
                        loss = criterion(outputs,labels)
                        _, preds = torch.max(outputs,1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                #计算损失
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            time_elapsed = time.time() - since
            print('Time elapsed {:.0f}m {:.0f}f'.format(time_elapsed // 60,time_elapsed % 60))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase,epoch_loss,epoch_acc))

            #得到最好的模型
            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                state = {
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer':optimizer.state_dict(),
                }
                torch.save(state,filename)
                if phase == 'valid':
                    val_acc_history.append(epoch_acc)
                    valid_losses.append(epoch_loss)
                    scheduler.step(epoch_loss)
                if phase == 'train':
                    train_acc_history.append(epoch_acc)
                    train_losses.append(epoch_loss)

        print('Optimizer learning rate : {:.7f}'.format(optimizer.param_groups[0]['lr']))
        LRs.append(optimizer.param_groups[0]['lr'])
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed //60,time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    #训练完后用最好的一次当做模型最终的结果
    model.load_state_dict(best_model_wts)
    return model,val_acc_history,train_acc_history.valid_losses,train_losses,LRs

Empezar a entrenar:

# 开始训练
model_ft,val_acc_history,train_acc_history,valid_lossea,train_losses,LRs = train_model(model_ft,dataloaders,criterion,optimizer_ft,num_epochs=20,is_inception=(model_name == 'inception'))

4.2 Entrenando todas las capas

Comenzamos con los parámetros de la capa mejor completamente conectada entrenada la última vez y entrenamos todas las capas en función de esto. La configuración param.requires_grad = Trueindica que a continuación se entrenará toda la red y luego se reducirá la tasa de aprendizaje. La función de caída es cada 7 veces 1/10 de, la función de pérdida permanece sin cambios

再继续训练所有层
for param in model_ft.parameters():
    param.requires_grad = True

#再继续训练所有的参数,学习率调小一点(lr)
optimizer = optim.Adam(params_to_update,lr = 1e-4)
#衰减函数(每七次衰减为原来的七分之一)
scheduler = optim.lr_scheduler.StepLR(optimizer_ft,step_size=7,gamma=0.1)

#损失函数
criterion = nn.NLLLoss()

 Importa los mejores resultados anteriores y comienza a entrenar:

#在之前训练得到最好的模型的基础上继续训练
checkpoint = torch.load(filename)
best_acc = checkpoint['best_acc']
model_ft.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

model_ft,val_acc_history,train_acc_history,valid_lossea,train_losses,LRs = train_model(model_ft,dataloaders,criterion,optimizer_ft,num_epochs=10,is_inception=(model_name == 'inception'))

5. Pruebe el efecto de red

5.1 Preprocesamiento de datos de prueba

Primero cambie el nombre del checkpoint.pthmodelo recién entrenado a serious.pthy luego cargue el modelo entrenado:

#加载训练好的模型
model_ft,input_size = initialize_model(model_name,102,feature_extract,use_pretrained=True)

#GPU模型
model_ft = model_ft.to(device)
#保存文件的名字
filename = 'serious.pth'
#加载模型
checkpoint = torch.load(filename)
best_acc = checkpoint['beat_acc']
model_ft.load_state_dict(checkpoint['state_dict'])

Defina la función de procesamiento de imágenes:

def process_image(image_path):
    img = Image.open(image_path)

    #Resize,thumbnail方法只能进行缩小,所以进行判断
    if img.size[0] > img.size[1]:
        img.thumbnail((10000,256))
    else:
        img.thumbnail((256,10000))

    #Crop操作
    left_margin = (img.width-224)/2
    bottom_margin = (img.height-224)/2
    right_margin = (left_margin) + 224
    top_margin = bottom_margin + 224
    img  = img.crop(left_margin,bottom_margin,right_margin,top_margin)

    #相同的预处理方法
    img = np.array(img)/255
    mean = np.array([0.485,0.456,0.406])
    std = np.array([0.229,0.224,0.225])
    img = (img - mean)/std

    #注意颜色通道应该放在第一个位置
    img = img.transpose((2,0,1))

    return img

Defina la función de visualización de imágenes:

#展示数据
def imshow(image,ax = None,title = None):
    if ax is None:
        fig,ax = plt.subplots()

    #颜色通道还原
    image = np.array(image).transpose((1,2,0))

    #预处理还原
    mean = np.array([0.485,0.456,0.406])
    std = np.array([0.229,0.224,0.225])
    image = std * image + mean
    image = np.clip(image,0.1)

    ax.imshow(image)
    ax.set_title(title)

    return ax

Mostrar un dato:

image_path = 'image_06621.jpg'
img = process_image(image_path)
imshow(img)

 

Obtenga datos de prueba por lotes:

#测试一个batch数据
dataiter = iter(dataloaders['valid'])
images,labels = dataiter.next()

model_ft.eval()

if train_on_gpu:
    output = model_ft(images.cuda())
else:
    output = model_ft(images)

Utilice la función torch.max() para calcular el valor de la etiqueta:

#得到属于类别的八个编号
_,preds_tensor = torch.ax(output,1)
preds = np.squeeze(preds_tensor.numpy()) if not train_on_gpu else np.squeeze(preds_tensor.cpu().numpy())

5.2 Visualización de resultados

#展示预测结果
fig = plt.figure(figsize=(20,20))
columns = 4
rows = 2

for idx in range(columns * rows):
    ax = fig.add_subplot(rows,columns,idx+1,xticks=[],yticks=[])
    plt.imshow(im_convert(images[idx]))
    ax.set_title("{} {}".format(cat_to_name[str(preds[idx])],cat_to_name[str(labels[idx].item())]),
                 color = ("green" if cat_to_name[str(preds[idx])] == cat_to_name[str(labels[idx].item())] else "red"))
plt.show()

Los resultados son los siguientes (el título verde representa el éxito del reconocimiento, el título rojo representa el error de reconocimiento, el valor real dentro de los corchetes y el valor previsto fuera de los corchetes)

Supongo que te gusta

Origin blog.csdn.net/weixin_64443786/article/details/132029099
Recomendado
Clasificación