import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from torchvision import models
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# train.csv is indexed by its 'filename' column; the lookup below uses the
# numeric image id (file name without extension) as the key.
label = pd.read_csv('train.csv')
label = label.set_index('filename')

# `images` must exist before `labels` is derived from it. Sort the paths because
# glob order is filesystem-dependent, and the train/val split below is positional.
images = sorted(glob.glob('train/*.jpg'))
# os.path.basename handles the platform's separator ('\\' on Windows, '/' on POSIX),
# unlike the former hard-coded split('\\'). `.item()` extracts the single label value
# and raises if the row does not hold exactly one value (int(Series) is deprecated).
labels = [int(label.loc[int(os.path.basename(i).split('.')[0])].item()) for i in images]
# Number of samples used for training; the remaining samples form the validation set.
num_train = int(len(labels) * 0.8)


class FoodDataset(Dataset):
    """Map-style dataset pairing image file paths with integer class labels."""

    def __init__(self, images, labels, transform=None):
        """Store the samples.

        Args:
            images: sequence of image file paths.
            labels: sequence of class labels aligned index-by-index with ``images``.
            transform: callable applied to each RGB PIL image (e.g. a torchvision
                ``Compose``). ``None`` returns the PIL image unchanged.

        Raises:
            ValueError: if ``images`` and ``labels`` have different lengths.
        """
        # A length mismatch would silently pair images with the wrong labels.
        if len(images) != len(labels):
            raise ValueError(
                'images and labels must have the same length, got {} and {}'.format(
                    len(images), len(labels)))
        self.images = images
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for sample ``index``."""
        # convert() returns a new, fully loaded image, so the source file
        # can be closed deterministically when the context manager exits.
        with Image.open(self.images[index]) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[index]

    def __len__(self):
        """Return the number of samples."""
        return len(self.labels)
# Per-channel normalisation statistics shared by both pipelines (they map
# ToTensor's [0, 1] range to [-1, 1]); show_batch inverts them for display.
NORM_MEAN = (0.5, 0.5, 0.5)
NORM_STD = (0.5, 0.5, 0.5)

transform_train = transforms.Compose([
    transforms.Resize([256, 256]),
    transforms.RandomHorizontalFlip(),  # augmentation only on the training split
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)])
transform_val = transforms.Compose([
    transforms.Resize([256, 256]),
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)])

# The first num_train samples train the model; the rest are used for validation.
train_dataset = FoodDataset(images[:num_train], labels[:num_train], transform_train)
train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
val_dataset = FoodDataset(images[num_train:], labels[num_train:], transform_val)
val_loader = DataLoader(dataset=val_dataset, batch_size=128, shuffle=False)


def show_batch(images_batch, mean=NORM_MEAN, std=NORM_STD):
    """Display a batch of normalised image tensors as one grid with matplotlib.

    Args:
        images_batch: image tensor batch as produced by the loaders above
            (presumably shape (B, C, H, W) — what ``utils.make_grid`` accepts).
        mean, std: per-channel statistics that were passed to
            ``transforms.Normalize``; they are inverted so colours display
            correctly. Defaults match the pipelines in this file.
    """
    grid = utils.make_grid(images_batch.detach().cpu())
    # Undo Normalize: imshow expects floats in [0, 1] and clips anything outside,
    # so plotting the raw [-1, 1] tensor would show distorted colours.
    mean_t = torch.as_tensor(mean, dtype=grid.dtype).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=grid.dtype).view(-1, 1, 1)
    grid = (grid * std_t + mean_t).clamp(0, 1)
    # CHW -> HWC for matplotlib.
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.show()
def build_model(num_classes, freeze_backbone=True):
    """Build an ImageNet-pretrained ResNet-18 with a fresh classification head.

    Args:
        num_classes: number of output logits of the replacement ``fc`` layer.
        freeze_backbone: if True (default, the original behaviour), every
            pretrained parameter gets ``requires_grad = False`` so only the new
            head is trainable.

    Returns:
        The modified torchvision ResNet-18 model.
    """
    weights_enum = getattr(models, 'ResNet18_Weights', None)
    if weights_enum is not None:
        # torchvision >= 0.13 deprecates `pretrained=True`; IMAGENET1K_V1 is the
        # checkpoint that `pretrained=True` loads.
        transfer_model = models.resnet18(weights=weights_enum.IMAGENET1K_V1)
    else:
        transfer_model = models.resnet18(pretrained=True)
    if freeze_backbone:
        for param in transfer_model.parameters():
            param.requires_grad = False
    # Replace the last fully-connected layer with one that outputs `num_classes`
    # logits. It is created after freezing, so its parameters stay trainable.
    dim = transfer_model.fc.in_features
    transfer_model.fc = nn.Linear(dim, num_classes)
    return transfer_model
# Four-class classifier placed on the selected device; the loss consumes raw logits.
net = build_model(num_classes=4)
net = net.to(device)
criterion = nn.CrossEntropyLoss()
# Alternative optimizer that was tried: torch.optim.Adam(net.parameters(), lr=1e-3)
# Only the new classification head (net.fc) is handed to the optimizer here.
optimizer = torch.optim.SGD(net.fc.parameters(), lr=1e-3)


def train(log_every=20):
    """Run one training epoch over ``train_loader``.

    Uses the module-level ``net``, ``criterion``, ``optimizer`` and ``device``.

    Args:
        log_every: positive number of batches between progress prints. Each
            print shows the mean loss over the preceding ``log_every`` batches.
    """
    net.train()
    batch_num = len(train_loader)
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader, start=1):
        # Move the batch to the same device as the model.
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Accumulate the loss and report its running mean.
        running_loss += loss.item()
        if i % log_every == 0:
            print('batch:{}/{} loss:{:.3f}'.format(i, batch_num, running_loss / log_every))
            running_loss = 0.0


def validate():
    """Evaluate ``net`` on ``val_loader``; print and return top-1 accuracy in percent.

    Returns:
        Accuracy as a float percentage, or ``None`` if the loader yields no samples.
    """
    # Essential: eval() makes normalisation layers (e.g. ResNet's BatchNorm) use
    # their running statistics instead of per-batch statistics.
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Guard against division by zero on an empty validation split.
    if total == 0:
        print('Validation set is empty; accuracy not computed.')
        return None
    accuracy = 100 * correct / total
    print('Accuracy of the network on the test images: %d %%' % accuracy)
    return accuracy
n_epoch =10for epoch inrange(n_epoch):print('epoch {}'.format(epoch+1))
train()
validate()
save_path ='params/param_{}.pkl'.format(epoch)
torch.save(net.state_dict(), save_path)'''
epoch 1
batch:20/39 loss:1.348
Accuracy of the network on the test images: 48 %
epoch 2
batch:20/39 loss:1.183
Accuracy of the network on the test images: 55 %
epoch 3
batch:20/39 loss:1.088
Accuracy of the network on the test images: 64 %
epoch 4
batch:20/39 loss:1.005
Accuracy of the network on the test images: 68 %
epoch 5
batch:20/39 loss:0.953
Accuracy of the network on the test images: 71 %
epoch 6
batch:20/39 loss:0.896
Accuracy of the network on the test images: 73 %
epoch 7
batch:20/39 loss:0.840
Accuracy of the network on the test images: 75 %
epoch 8
batch:20/39 loss:0.797
Accuracy of the network on the test images: 77 %
epoch 9
batch:20/39 loss:0.770
Accuracy of the network on the test images: 78 %
epoch 10
batch:20/39 loss:0.729
Accuracy of the network on the test images: 78 %
'''
# Phase 2: continue training in-process for epochs 11-20.
# To resume from a checkpoint instead: net.load_state_dict(torch.load(save_path))
for epoch in range(10, 20):
    print(f'epoch {epoch + 1}')
    train()
    validate()
    save_path = f'params/param_{epoch}.pkl'
    torch.save(net.state_dict(), save_path)
'''
epoch 11
batch:20/39 loss:0.704
Accuracy of the network on the test images: 80 %
epoch 12
batch:20/39 loss:0.675
Accuracy of the network on the test images: 81 %
epoch 13
batch:20/39 loss:0.666
Accuracy of the network on the test images: 81 %
epoch 14
batch:20/39 loss:0.655
Accuracy of the network on the test images: 82 %
epoch 15
batch:20/39 loss:0.633
Accuracy of the network on the test images: 83 %
epoch 16
batch:20/39 loss:0.608
Accuracy of the network on the test images: 84 %
epoch 17
batch:20/39 loss:0.588
Accuracy of the network on the test images: 84 %
epoch 18
batch:20/39 loss:0.586
Accuracy of the network on the test images: 84 %
epoch 19
batch:20/39 loss:0.575
Accuracy of the network on the test images: 84 %
epoch 20
batch:20/39 loss:0.561
Accuracy of the network on the test images: 85 %
'''
optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)# 注意这里把 net.fc 改成了 netfor param in net.parameters():
param.requires_grad =Truefor epoch inrange(20,30):print('epoch {}'.format(epoch+1))
train()
validate()
save_path ='params/param_{}.pkl'.format(epoch)
torch.save(net.state_dict(), save_path)'''
epoch 21
batch:20/39 loss:0.509
Accuracy of the network on the test images: 87 %
epoch 22
batch:20/39 loss:0.467
Accuracy of the network on the test images: 88 %
epoch 23
batch:20/39 loss:0.395
Accuracy of the network on the test images: 88 %
epoch 24
batch:20/39 loss:0.395
Accuracy of the network on the test images: 89 %
epoch 25
batch:20/39 loss:0.366
Accuracy of the network on the test images: 89 %
epoch 26
batch:20/39 loss:0.337
Accuracy of the network on the test images: 90 %
epoch 27
batch:20/39 loss:0.329
Accuracy of the network on the test images: 91 %
epoch 28
batch:20/39 loss:0.293
Accuracy of the network on the test images: 91 %
epoch 29
batch:20/39 loss:0.282
Accuracy of the network on the test images: 91 %
epoch 30
batch:20/39 loss:0.267
Accuracy of the network on the test images: 92 %
'''