# 0014-PyTorch入门-二分类迁移学习实战 (PyTorch basics: hands-on binary-classification transfer learning)


import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import transforms
from torchvision import models
from torchvision.models import ResNet
import numpy as np
import matplotlib.pyplot as plt
import os

# Root directory of the dataset. It must contain 'train' and 'val' sub-folders,
# each laid out as one sub-directory per class (the ImageFolder convention).
data_dir = "/hymenoptera_data/"

# Per-channel normalization shared by both splits so that training and
# validation inputs are scaled identically.
_normalize = transforms.Normalize(mean=(0.485, 0.456, 0.406),
                                  std=(0.229, 0.224, 0.225))

# Training split: random crop + random horizontal flip as data augmentation.
train_dataset = torchvision.datasets.ImageFolder(
    root=os.path.join(data_dir, 'train'),
    transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        _normalize,
    ]))

# Validation split: deterministic resize + center crop. Random augmentation
# must not be applied here, otherwise the reported accuracy changes from one
# evaluation to the next even for identical weights.
val_dataset = torchvision.datasets.ImageFolder(
    root=os.path.join(data_dir, 'val'),
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        _normalize,
    ]))

# `shuffle` is a boolean flag: reshuffle the training data every epoch, keep
# the validation order fixed (accuracy over the full set is order-independent).
train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
# NOTE: the misspelled name `val_dataloder` is kept because the rest of the
# script refers to it.
val_dataloder = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False)

# Class names. They live on the ImageFolder dataset, not on the DataLoader
# that wraps it.
class_names = train_dataset.classes
print('class_names:{}'.format(class_names))

# Training device: first CUDA GPU when available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
print('train_device:{}'.format(device.type))

# Preview one training batch as an image grid.
# Each batch is an (images, labels) pair; only the images go into the grid.
preview_images, _ = next(iter(train_dataloader))
preview_grid = torchvision.utils.make_grid(preview_images)

# make_grid returns a (C, H, W) tensor while imshow expects (H, W, C).
preview_grid = preview_grid.numpy().transpose((1, 2, 0))

# Undo the Normalize step of the dataset transform (same mean/std values) so
# the colours are displayable, then clamp to the valid [0, 1] range.
preview_mean = np.array([0.485, 0.456, 0.406])
preview_std = np.array([0.229, 0.224, 0.225])
preview_grid = np.clip(preview_grid * preview_std + preview_mean, 0, 1)

plt.figure()
plt.imshow(preview_grid)
# Outside an interactive session, call plt.show() to actually open the window.

# ------------- model, loss, optimizer and learning-rate schedule -------------
# Start from a ResNet-18 with pretrained weights (transfer learning).
model = models.resnet18(pretrained=True)

# Number of input features of the final fully connected layer.
num_fc_in = model.fc.in_features

# Replace the classification head: two classes, so out_features = 2.
model.fc = nn.Linear(num_fc_in, 2)

# Move the model to the selected device (CPU or GPU).
model = model.to(device)

# Loss function.
loss_fc = nn.CrossEntropyLoss()

# Optimizer: SGD with momentum over all parameters (full fine-tuning).
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

# Learning-rate schedule: multiply the learning rate by 0.5 every 10 epochs.
# The optimizer must be passed as a keyword argument (`optimizer=`); writing
# `optimizer==optimizer` would hand the boolean True to StepLR instead.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.5)

# Training loop: fine-tune the model and, every 20 mini-batches, report the
# mean training loss of those batches plus the accuracy on the validation set.
num_epochs = 50
for epoch in range(num_epochs):
    running_loss = 0.0

    for i, sample_batch in enumerate(train_dataloader):
        # Each batch is an (inputs, labels) pair; move both to the device.
        inputs = sample_batch[0].to(device)
        labels = sample_batch[1].to(device)

        # Set training mode on every step because the periodic evaluation
        # below switches the model to eval mode.
        model.train()

        optimizer.zero_grad()

        # Forward pass and loss.
        outputs = model(inputs)
        loss = loss_fc(outputs, labels)

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Periodic evaluation on the validation set.
        if i % 20 == 19:
            correct = 0
            total = 0
            model.eval()
            # No gradients are needed for evaluation; disabling them avoids
            # building autograd graphs and saves memory and time.
            with torch.no_grad():
                for images_test, labels_test in val_dataloder:
                    images_test = images_test.to(device)
                    labels_test = labels_test.to(device)

                    outputs_test = model(images_test)
                    # Predicted class = index of the largest logit.
                    _, prediction = torch.max(outputs_test, 1)
                    correct += (torch.sum((prediction == labels_test))).item()
                    total += labels_test.size(0)
            print('[{}, {}] running_loss = {:.5f} accurcay = {:.5f}'.format(epoch + 1, i + 1, running_loss / 20,
                                                                            correct / total))
            running_loss = 0.0

    # Advance the learning-rate schedule once per epoch, after the optimizer
    # updates of that epoch, so the initial learning rate is actually used.
    exp_lr_scheduler.step()

# Reported once, after all epochs have completed.
print('training finish !')
# To persist the weights: torch.save(model.state_dict(), './model/model_2.pth')




# 猜你喜欢

# 转载自 blog.csdn.net/zhonglongshen/article/details/112748424