(一)pytorch单任务图像分类

深度学习主要由:数据读取、网络模型、损失函数、优化器这四个部分构成

最开始不应该纠结于这些细节,应该先让代码跑起来再去研究代码是怎么写的

下面的代码只是训练部分的代码,并加上验证模型准确率的功能。

1.项目分布:创建一个文件夹my_dataset1,在my_dataset1里面创建train和valid这两个文件夹

(文件夹名称固定,train和valid不要写错,不然代码跑不起来)

  train是训练集的图片,valid是验证集的图片。在train这个文件夹里面,你训练多少个类别就创建多少个文件夹(比如我只训练两类就只创建两个文件夹cat和dog,文件夹名称不固定)。

  valid文件夹的格式和train的格式一样。

2.代码参数介绍:如果你训练的类别为3,就把代码里面的num_classes=2改成num_classes=3

   代码默认只训练100轮,想训练200轮的话就把代码里面的Epoches=100改成Epoches=200

   代码默认Batch_size为4,设置多少与显卡有关,显卡越好可以设的值就越大。

   代码默认将图片resize成[224, 224]再进行训练,想改的话可以对Image_Size进行修改

3.训练代码train.py:里面用到的模型是resnet18,并加载预训练模型进行训练,然后冻结前30层

import torch
from torch.autograd import Variable
import torchvision
from torchvision import datasets, transforms, models
import os
import matplotlib.pyplot as plt
import time
import torch.optim as optim

from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"


def train():
    """Run one full pass over the training set and return the summed loss.

    Relies on module-level globals set up in ``__main__``: ``train_data``
    (DataLoader), ``net`` (model), ``criterion`` (loss function),
    ``optimizer`` and ``device``.

    Returns:
        float: sum of the per-batch loss values for this epoch.
    """
    # Explicitly switch to training mode so BatchNorm/Dropout behave
    # correctly even if test() (eval mode) ran before this epoch.
    net.train()
    running_loss = 0.0
    for data, target in train_data:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out = net(data)
        loss = criterion(out, target)
        running_loss += loss.item()
        loss.backward()
        optimizer.step()

    return running_loss


def test():
    """Evaluate classification accuracy on the validation set and print it.

    Relies on module-level globals set up in ``__main__``: ``val_data``
    (DataLoader), ``net`` (model) and ``device``. Returns nothing.
    """
    # eval() is required: ResNet-18 contains BatchNorm layers, which must
    # use running statistics (not batch statistics) during evaluation.
    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for data, target in val_data:
            data, target = data.to(device), target.to(device)
            out = net(data)
            # argmax over the class dimension; avoids the deprecated .data access.
            prediction = out.argmax(dim=1)
            total += target.size(0)
            correct += (prediction == target).sum().item()
    # Guard against an empty validation loader (would otherwise divide by zero).
    if total:
        print('Accuracy on test set: (%d/%d)=%d %%' % (correct, total, 100 * correct / total))
    else:
        print('Accuracy on test set: no validation samples found')

if __name__ == '__main__':
    # Entry point: load an ImageFolder-style dataset, fine-tune a pretrained
    # ResNet-18, report validation accuracy each epoch, save the model, and
    # finally plot the per-epoch training loss.
    loss_list = []           # summed training loss per epoch (plotted below)
    Epoches = 100            # number of training epochs
    Batch_Size = 4           # batch size; raise it if the GPU has enough memory
    Image_Size = [224, 224]  # images are resized to this (H, W) before training

    # 1. Data loading
    data_dir = r'D:\Code\python\完整项目放置\classify_project\multi_classification\my_dataset1'
    # 1.1 Transforms applied to every image: resize, then convert to tensor.
    data_transform = {x: transforms.Compose([transforms.Resize(Image_Size), transforms.ToTensor()]) for x in
                      ["train", "valid"]}
    # ImageFolder expects data_dir/train/<class>/... and data_dir/valid/<class>/...
    # — one sub-folder per class; folder names become the class labels.
    image_datasets = {x: datasets.ImageFolder(root=os.path.join(data_dir, x), transform=data_transform[x]) for x in
                      ["train", "valid"]}
    dataloader = {x: torch.utils.data.DataLoader(dataset=image_datasets[x], batch_size=Batch_Size, shuffle=True) for x in
                  ["train", "valid"]}
    train_data, val_data = dataloader["train"], dataloader["valid"]

    # Mapping from class name to integer label, derived from the folder names.
    index_classes = image_datasets["train"].class_to_idx
    print(index_classes)
    example_classes = image_datasets["train"].classes
    print(example_classes)

    # 2. Data preview; keep this commented out during actual training.
    # X_example, y_example = next(iter(dataloader["train"]))
    # img = torchvision.utils.make_grid(X_example)
    # img = img.numpy().transpose([1, 2, 0])
    # for i in range(len(y_example)):
    #     index = y_example[i]
    #     print(example_classes[index], end='   ')
    #     if (i+1)%8 == 0:
    #         print()
    # plt.imshow(img)
    # plt.show()

    # 3. Load the pretrained model and adapt it for fine-tuning:
    # replace the final fully-connected layer to match our class count.
    net = models.resnet18(pretrained=True)
    fc_features = net.fc.in_features

    # Number of target classes; set this to match your dataset (2 here).
    num_classes = 2
    net.fc = torch.nn.Linear(fc_features, num_classes)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 4. Fine-tuning: freeze the first 30 entries of net.parameters().
    # NOTE(review): this freezes the first 30 *parameter tensors*, not 30
    # layers — each conv/bn layer contributes several tensors. The new fc
    # layer's parameters come last, so they stay trainable.
    for i, param in enumerate(net.parameters()):
        if i < 30:
            param.requires_grad = False
    net.to(device)

    # 5. Loss function and optimizer; only un-frozen parameters are optimized.
    LR = 1e-3
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=LR)

    for epoch in range(Epoches):
        loss = train()
        loss_list.append(loss)
        print("第%d轮的loss为:%5f:" % (epoch, loss))
        test()

        # Alternative: net.state_dict() saves only the parameters (smaller file):
        # torch.save(net.state_dict(), 'Model2.pth')

        # Save the whole model object (overwritten every epoch).
        torch.save(net, "my_model.pth")

    # Plot the per-epoch training-loss curve.
    plt.title("Graph")
    plt.plot(range(Epoches), loss_list)
    plt.ylabel("loss")
    plt.xlabel("epoch")
    plt.show()

猜你喜欢

转载自blog.csdn.net/m0_48095841/article/details/125660998