# -*- coding: utf-8 -*-
# @Time : 2021/1/24 14:34
# @Author : Johnson
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
import torchvision
from torchvision.transforms import transforms
from torchvision import models
from torchvision.models import ResNet
import numpy as np
import matplotlib.pyplot as plt
import os
data_dir = 'C:/datafolder/hymenoptera_data/'

# Per-channel mean/std of ImageNet; the backbone below is loaded with
# ImageNet-pretrained weights, so inputs are normalized the same way.
imagenet_normalize = transforms.Normalize(
    mean=(0.485, 0.456, 0.406),
    std=(0.229, 0.224, 0.225))

# Training images: random resized crop + horizontal flip for augmentation.
train_dataset = torchvision.datasets.ImageFolder(
    root=os.path.join(data_dir, 'train'),
    transform=transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        imagenet_normalize,
    ]))

# Validation images: deterministic resize + center crop (no random
# augmentation), so the reported accuracy is reproducible between runs.
val_dataset = torchvision.datasets.ImageFolder(
    root=os.path.join(data_dir, 'val'),
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        imagenet_normalize,
    ]))

# `shuffle` is a bool flag; shuffle training batches every epoch, but keep
# validation order fixed since it does not affect the accuracy computation.
train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False)
# Class labels as discovered by ImageFolder in the training directory.
class_names = train_dataset.classes
print(f'class_names:{class_names}')

# Training device: first CUDA GPU if present, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f"train_device:{device.type}")

# Optional: show one random training batch. Left disabled; note that
# `utils` (providing `imshow`) is not imported anywhere in this file.
# plt.figure()
# utils.imshow(next(iter(train_dataloader)))
# plt.show()
# -------------- Model, loss, optimizer and learning-rate schedule --------------
# ResNet-18 initialised with ImageNet-pretrained weights.
# NOTE: torchvision >= 0.13 deprecates `pretrained=` in favour of
# `weights=models.ResNet18_Weights.DEFAULT`; kept here for older versions.
model = models.resnet18(pretrained=True)
# Replace the final fully connected layer with one that outputs a logit per
# class found in the training folder (the original hard-coded 2 for this
# binary task; deriving it from class_names keeps the head and labels in sync).
num_fc_in = model.fc.in_features
model.fc = nn.Linear(num_fc_in, len(class_names))
# Move the model to the selected CPU/GPU device.
model = model.to(device)
# Loss function.
loss_fc = nn.CrossEntropyLoss()
# Optimizer: SGD with momentum over all parameters (full fine-tuning).
optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
# Learning-rate schedule: multiply the LR by 0.5 every 10 scheduler steps
# (one step per epoch in the training loop).
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer, step_size=10, gamma=0.5)
# ------------------------------ Training loop ------------------------------
def _evaluate_accuracy(net, dataloader, dev):
    """Return the top-1 accuracy of ``net`` over every batch of ``dataloader``.

    Puts ``net`` in eval mode and runs without autograd; the caller must
    switch it back to train mode afterwards. Returns 0.0 if no samples
    were seen (avoids a division by zero).
    """
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # evaluation needs no gradient graph; saves memory
        for images, targets in dataloader:
            images = images.to(dev)
            targets = targets.to(dev)
            _, prediction = torch.max(net(images), 1)
            correct += (prediction == targets).sum().item()
            total += targets.size(0)
    return correct / total if total else 0.0


num_epochs = 50
log_interval = 20  # report mean training loss / val accuracy every N batches
for epoch in range(num_epochs):
    running_loss = 0.0
    model.train()
    for i, (inputs, labels) in enumerate(train_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)            # forward pass
        loss = loss_fc(outputs, labels)    # loss
        loss.backward()                    # backward pass
        optimizer.step()                   # parameter update
        running_loss += loss.item()

        # Periodically evaluate on the validation set.
        if i % log_interval == log_interval - 1:
            accuracy = _evaluate_accuracy(model, val_dataloader, device)
            print('[{}, {}] running_loss = {:.5f} accuracy = {:.5f}'.format(
                epoch + 1, i + 1, running_loss / log_interval, accuracy))
            running_loss = 0.0
            model.train()  # evaluation left the model in eval mode
    # Since PyTorch 1.1 the LR scheduler must step *after* the epoch's
    # optimizer steps; stepping first shifts the StepLR schedule one epoch early.
    exp_lr_scheduler.step()
print('training finish !')
# torch.save(model.state_dict(), './model/model_2.pth')
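# Original article: "0027-pytorch入门-利用resnet18 finetune进行图片二分类"
# (PyTorch intro #0027: binary image classification by fine-tuning ResNet-18)
# Reposted from: blog.csdn.net/zhonglongshen/article/details/113091967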