# 训练模型并保存 (Train the model and save it)
import torch
import torch. nn as nn
import torch. optim as optim
from torchvision import datasets, transforms, models
from torch. utils. data import Dataset
import sys
# Training-time augmentation: random resized crop, rotation and horizontal
# flip, producing 224x224 tensors.
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomRotation(20),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
])
# Evaluation preprocessing must be deterministic: reusing the random
# augmentation above for the test set would make the reported test accuracy
# vary from run to run. Resize + CenterCrop gives fixed 224x224 tensors so
# images of different aspect ratios can still be stacked into one batch.
test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
root = "image"
# ImageFolder expects one sub-directory per class under each split.
train_dataset = datasets.ImageFolder(root + "/train", transform=transform)
test_dataset = datasets.ImageFolder(root + "/test", transform=test_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)
# Order is irrelevant when only counting correct predictions.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=8, shuffle=False)
classes = train_dataset.classes
classes_index = train_dataset.class_to_idx
print(classes)
print(classes_index)
# Start from VGG16 with ImageNet-pretrained weights and freeze every existing
# parameter, so that only the new head defined below is trained.
model = models.vgg16(pretrained=True)
print(model)
for param in model.parameters():
    param.requires_grad = False
# Replace the classifier with a small head. 25088 matches the input width of
# VGG16's original classifier (512 * 7 * 7 flattened features); the final
# layer has 2 outputs, one per class folder.
model.classifier = torch.nn.Sequential(
    torch.nn.Linear(25088, 100),
    torch.nn.ReLU(),
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(100, 2),
)
LR = 0.0003
entropy_loss = nn.CrossEntropyLoss()
# Only the freshly created head has requires_grad=True; give the optimizer
# just those parameters instead of the whole, mostly frozen, network.
optimizer = optim.Adam(model.classifier.parameters(), lr=LR)
def train():
    """Train ``model`` for a single pass over ``train_loader``.

    Switches the model to training mode (dropout active), then for each
    mini-batch clears stale gradients, computes the cross-entropy loss,
    back-propagates it and applies one ``optimizer`` update.
    """
    model.train()
    for inputs, labels in train_loader:
        # Gradients accumulate by default, so reset them before each batch.
        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = entropy_loss(logits, labels)
        batch_loss.backward()
        optimizer.step()
def _count_correct(loader):
    """Return how many samples from ``loader`` ``model`` classifies correctly."""
    correct = 0
    for inputs, labels in loader:
        out = model(inputs)
        # Index of the highest logit is the predicted class.
        _, predicted = torch.max(out, 1)
        correct += (predicted == labels).sum().item()
    return correct


def test():
    """Print the model's accuracy on the test set and then on the training set.

    Runs in eval mode (dropout disabled) and under ``torch.no_grad()`` so no
    autograd graph is built for these forward passes.
    """
    model.eval()
    with torch.no_grad():
        test_correct = _count_correct(test_loader)
        print("test acc:{0}".format(test_correct / len(test_dataset)))
        train_correct = _count_correct(train_loader)
        print("train acc:{0}".format(train_correct / len(train_dataset)))
# Fine-tune for a fixed number of epochs, reporting accuracy after each one,
# then write the learned weights to disk.
NUM_EPOCHS = 5
for epoch in range(NUM_EPOCHS):
    print("epoch:", epoch)
    train()
    test()
# The full state_dict is saved: frozen backbone weights plus the trained head.
checkpoint_state = model.state_dict()
torch.save(checkpoint_state, "cat_dog.pth")
# 加载模型进行预测 (Load the saved model and run a prediction)
import torch
import numpy as np
from PIL import Image
from torchvision import transforms, models
# Rebuild the architecture used during training (VGG16 with the custom 2-way
# head) so the saved state_dict's keys and shapes line up.
# No ImageNet download is needed: the checkpoint holds the full state_dict,
# and load_state_dict (strict by default) overwrites every weight anyway.
model = models.vgg16(pretrained=False)
model.classifier = torch.nn.Sequential(
    torch.nn.Linear(25088, 100),
    torch.nn.ReLU(),
    torch.nn.Dropout(p=0.5),
    torch.nn.Linear(100, 2),
)
# map_location="cpu" lets a checkpoint saved on another device load anywhere.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (on torch >= 1.13 consider weights_only=True).
model.load_state_dict(torch.load("cat_dog.pth", map_location="cpu"))
# Eval mode disables dropout for inference.
model.eval()
# Index -> class name. Presumably matches ImageFolder's alphabetical class
# ordering used at training time (cat=0, dog=1); verify against class_to_idx.
label = np.array(["cat", "dog"])
# Inference preprocessing. Resize(224) alone only scales the shorter side and
# leaves a non-square image; CenterCrop(224) gives the same 224x224 input
# size the model was trained on.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
def predict(image_path):
    """Classify a single image file and print its predicted label.

    Args:
        image_path: Path to an image file that PIL can open.

    Returns:
        The predicted label as a ``str`` (it is also printed).
    """
    # Convert to RGB: greyscale or RGBA (e.g. PNG) images would otherwise
    # become 1- or 4-channel tensors that VGG16's 3-channel first conv
    # rejects. The context manager closes the underlying file handle.
    with Image.open(image_path) as img:
        batch = transform(img.convert("RGB")).unsqueeze(0)  # add batch dim
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        outputs = model(batch)
    _, predicted = torch.max(outputs, 1)
    result = label[predicted.item()]
    print(result)
    return str(result)


predict("image/test/cat/cat.1490.jpg")