Combate real de Pytorch [realizar clasificación de imágenes]

Implementar Pytorch para completar la clasificación de categorías


Objeto

  • Dominio básico del uso del framework pytorch para tareas de entrenamiento de redes neuronales
  • Use Pycharm, Google Colab para completar la escritura de código
  • Este experimento es solo para familiarizarse con el proceso de entrenamiento, por lo que el modelo es relativamente simple.

1. Escribir código

Introducción al conjunto de datos

El conjunto de datos CIFAR-10 contiene 6000 datos de imagen de tamaño (32,32), con 10 categorías. El conjunto de entrenamiento tiene 5000 imágenes y el conjunto de prueba tiene 1000 imágenes.

inserte la descripción de la imagen aquí

Lectura de datos y carga de datos.

# 创建一个transform
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])
# 准备数据
# 参数 train=True 表示是训练数据 ,False是测试数据
train_data = torchvision.datasets.CIFAR10("./dataset", train=True, transform=transform,
                                          download=False)

test_data = torchvision.datasets.CIFAR10("./pytorch/dataset", train=False, transform=transform,
                                         download=False)
# 加载数据
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

Estructura de directorios

inserte la descripción de la imagen aquí

  • La red está escrita en la estructura de red de vgg16

La arquitectura de la red es la siguiente

inserte la descripción de la imagen aquí

el código

import torch
from torch import nn

# 定义网路结构
class VGG16(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.model = nn.Sequential(

            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )

    def forward(self, input):
        output = self.model(input)
        return output
if __name__ == '__main__':
    mymodel =VGG16()
    input = torch.ones((64,3,32,32))
    output = mymodel(input)
    print(output.shape)

plot_util.py

import matplotlib.pyplot as plt
import seaborn as sns

# 画出train图线
def plot(train_loss):
    # sns.set()
    sns.set_style("dark")
    # sns.despine()

    idx_list = [i for i in range(len(train_loss))]

    plt.figure(figsize=(10, 6))
    plt.rcParams["font.size"] = 18
    plt.grid(visible=True, which='major', linestyle='-')
    plt.grid(visible=True, which='minor', linestyle='--', alpha=0.5)
    # 显示小刻度  minorticks_off()不显示
    plt.minorticks_on()

    plt.plot(idx_list, train_loss, 'o-', color='red', marker='*', linewidth=1, fillstyle='bottom')

    plt.title("traning loss")
    plt.xlabel("train times")
    plt.ylabel("train loss")
    plt.legend(["positive", "commend"])
    plt.savefig("train_loss2.png")
    # plt.show()
    plt.close()

tren

  • definir parámetros
  • modelo de carga
  • guardar modelo
  • Dibuja la función train_loss
  • De manera predeterminada, el archivo .pth del modelo entrenado se carga desde el directorio del modelo cada vez y se selecciona la carga con el subíndice más grande.
def train(model,maxepoch=20) :
    mynetwork = model
    # 定义损失函数
    loss_fn = nn.CrossEntropyLoss().to(device)
    # 定义学习率
    learning_rate = 0.01
    # 优化器
    optimizer = torch.optim.SGD(mynetwork.parameters(), learning_rate)

    # 设置训练网络的参数
    total_train_step = 0
    total_test_step = 0
    # 训练轮数
    epoch = 0
    max_epoch = maxepoch
    train_loss = []
    test_accuaacy = []
    state = {
    
    'model':mynetwork.state_dict(),
             'optimizer':optimizer.state_dict(),
             'epoch':epoch
             }
    model_save_path = './result/model/'
    model_load_path = './result/model/'
    # 从加载model的路径下获取所有文件(如果是.pth后缀的文件)
    model_files = [file for file in os.listdir(model_load_path) if file.endswith('.pth') ]
    model_files.sort(key =lambda x :int((x.split('.')[0]).split('_')[1]))
    # maxx = int ((model_files[-1].split('.')[0]).split('_')[1])
    # 如果大于0 ,就可以加载
    if len(model_files) >0 :
        path = model_load_path+model_files[-1]
        checkpoint = torch.load(path)
        mynetwork.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        epoch = int ((model_files[-1].split('.')[0]).split('_')[1])
        print('----load model -----')


    for i in range(epoch,max_epoch):
        print("[----------- {} epoch train ------------]".format(i + 1))
        mynetwork.train()
        for data in train_dataloader:
            imgs, targets = data
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = mynetwork(imgs)
            loss = loss_fn(outputs, targets)

            # 优化器
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_train_step += 1
            if total_train_step % 100 == 0:
                print("the {} times train and loss : {} ".format(total_train_step, loss.item()))
            train_loss.append(loss.item())

        # 保存训练模型
        current_train_model_name = "model_{}.pth".format(i+1)
        torch.save(state,model_save_path+current_train_model_name)

        # 测试
        mynetwork.eval()
        total_test_loss = 0
        total_accuracy = 0
        with torch.no_grad():
            for data in test_dataloader:
                imgs, targets = data
                imgs = imgs.to(device)
                targets = targets.to(device)
                outputs = mynetwork(imgs)

                loss = loss_fn(outputs, targets)
                total_test_loss += loss.item()
                accuracy = (outputs.argmax(1) == targets).sum()
                total_accuracy += accuracy
        print("total loss in test : {} .".format(total_test_loss))
        print("total accuracy in test : {}% ".format(total_accuracy / test_data_size * 100))

        total_test_step += 1
    plot(train_loss)
if __name__ == '__main__':
    # 搭建神经网络
    mynetwork = VGG16().to(device)

    parser = ArgumentParser()
    parser.add_argument('-e', '--maxepoch', help='train max epoch',
                        default=40, type=int)
    parser.add_argument('-b', '--batch_size', help='Training batch size',
                        default=64, type=int)
    args = parser.parse_args()
    train(mynetwork ,args.maxepoch)
    print("---over---")

prueba

import os

import torch
import torchvision
from PIL import Image
from torch import nn
from network.Mynetwork import VGG16

classes = ('airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
# 测试图片
img_path = "../images/horse.jpg"
img = Image.open(img_path)
# 由于png格式的图片格式不是3通道的需要转换成RGB格式
if img_path.endswith(".png"):
    img = img.convert('RGB')
path = r'./result/model/'
transform =torchvision.transforms.Compose([
    torchvision.transforms.Resize((32,32)),
    torchvision.transforms.ToTensor()
])
# 将图片转化成大小是 (32,32)大小的,并转换成tensor张量格式
img = transform(img)
# 获取所有的文件
files = [ file for file in os.listdir(path) if file.endswith('.pth') ]
files.sort(key=lambda x :int((x.split('.')[0]).split('_')[1]) )

# 加载最大的
load_path = path +files[-1]
checkpoint = torch.load(path+files[-1])
# model = torch.load(checkpoint['model'])
model = VGG16()

model.load_state_dict(checkpoint['model'])
# (batch_size,channel,height,width)
img = torch.reshape(img,(1,3,32,32))
model.eval()
with torch.no_grad() :
    output = model(img)
# print(output)
print(classes[output.argmax(1)])

Salida: caballo

todos los códigos

Enlace: https://pan.baidu.com/s/1cAtTvj_8kYjmU-V42cAApg Contraseña: 53dv

posición

  • Debe modificar la ruta y el conjunto de datos modificará la dirección de descarga de CIFAR10 según lo que desee
  • El código se ejecuta en el entorno ubuntu.

Implementar en goolge cloab

  • Dado que necesita usar la tarjeta gráfica para el entrenamiento, echemos un vistazo a la colaboración de Goolge.
  • Si lo usa, puede ejecutarlo a continuación, si no, use lo anterior para ejecutarlo en Pycharm

Enlace: https://pan.baidu.com/s/1u7ZYaFD3b-4Uu4KkQ4tsDA Contraseña: 2eur

Supongo que te gusta

Origin blog.csdn.net/qq_41661809/article/details/124972685
Recomendado
Clasificación