# PyTorch 迁移学习多分类 (ResNet-18) — transfer learning for multi-class image classification with ResNet-18

import glob
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from torchvision import transforms, utils

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# train.csv maps an integer image id ('filename' column) to its class label.
# NOTE(review): assumes exactly one label column besides 'filename' — confirm.
label = pd.read_csv('train.csv')
label = label.set_index('filename')
label_by_id = label.iloc[:, 0]

# The image list must exist before the labels are looked up. Sort it so the
# (unshuffled) 80/20 train/validation split below is reproducible across runs
# and operating systems; glob order is otherwise unspecified.
images = sorted(glob.glob('train/*.jpg'))
# Path(...).stem turns 'train/123.jpg' (or 'train\\123.jpg' on Windows) into '123'.
labels = [int(label_by_id.loc[int(Path(i).stem)]) for i in images]

num_train = int(len(labels) * 0.8)

class FoodDataset(Dataset):
    """Labelled image dataset backed by a list of image file paths.

    Args:
        images: sequence of image file paths.
        labels: class labels, aligned index-for-index with ``images``.
        transform: callable applied to each RGB ``PIL.Image`` (for example a
            ``torchvision.transforms.Compose`` pipeline).
    """

    def __init__(self, images, labels, transform):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for sample ``index``."""
        # Open inside a context manager so the file handle is closed once
        # convert() has produced an independent, fully loaded RGB copy.
        with Image.open(self.images[index]) as raw:
            img = raw.convert('RGB')
        return self.transform(img), self.labels[index]

    def __len__(self):
        return len(self.labels)

def _build_transform(augment):
    """Return the preprocessing pipeline used for train/validation images.

    Every image is resized to 256x256, converted to a tensor and normalised
    per channel with mean 0.5 / std 0.5. ``augment`` additionally inserts a
    random horizontal flip (training only).
    """
    steps = [transforms.Resize([256, 256])]
    if augment:
        steps.append(transforms.RandomHorizontalFlip())
    steps.extend([
        transforms.ToTensor(),
        transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
    ])
    return transforms.Compose(steps)


transform_train = _build_transform(augment=True)
transform_val = _build_transform(augment=False)

# First num_train samples train the model; the remainder is held out for validation.
train_dataset = FoodDataset(images[:num_train], labels[:num_train], transform_train)
val_dataset = FoodDataset(images[num_train:], labels[num_train:], transform_val)

train_loader = DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=128, shuffle=False)

def show_batch(images_batch, mean=(.5, .5, .5), std=(.5, .5, .5)):
    """Display a batch of normalised image tensors as one grid.

    Args:
        images_batch: a (B, C, H, W) image tensor batch, e.g. from the
            DataLoaders in this script.
        mean, std: per-channel statistics that were given to
            ``transforms.Normalize``. They are used to undo the normalisation,
            so pixels are back in [0, 1] before plotting. The defaults match
            the ``Normalize(mean=[.5]*3, std=[.5]*3)`` used in this script.
    """
    # Work on a detached CPU copy so GPU tensors (or ones carrying autograd
    # history) can be converted to NumPy for matplotlib.
    batch = images_batch.detach().cpu()
    mean_t = torch.tensor(mean, dtype=batch.dtype).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=batch.dtype).view(-1, 1, 1)
    # Invert Normalize: x = x_norm * std + mean, then clamp to imshow's valid range.
    batch = (batch * std_t + mean_t).clamp(0, 1)
    grid = utils.make_grid(batch)
    # make_grid returns (C, H, W); matplotlib expects (H, W, C).
    plt.imshow(grid.numpy().transpose((1, 2, 0)))
    plt.show()

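# [Figure lost in extraction — only CSDN's placeholder alt text "在这里插入图片描述"
#  ("insert image description here") survived; presumably a batch shown by show_batch().]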


def build_model(num_classes, freeze_backbone=True):
    """Build an ImageNet-pretrained ResNet-18 with a fresh classification head.

    Args:
        num_classes: number of output logits of the replacement ``fc`` layer.
        freeze_backbone: if True (the default, matching the original
            behaviour), set ``requires_grad=False`` on every pretrained
            parameter so that only the new head is trainable.

    Returns:
        The modified torchvision ResNet-18 model.
    """
    transfer_model = models.resnet18(pretrained=True)
    if freeze_backbone:
        for param in transfer_model.parameters():
            param.requires_grad = False

    # Replace the final fully connected layer with one that outputs
    # num_classes logits. It is created after freezing, so its parameters
    # keep requires_grad=True and stay trainable.
    in_features = transfer_model.fc.in_features
    transfer_model.fc = nn.Linear(in_features, num_classes)

    return transfer_model

# Stage-1 setup: a 4-class model on the selected device. The optimizer only
# receives the new fc head's parameters; the backbone is frozen by build_model.
net = build_model(4).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=net.fc.parameters(), lr=0.001)


def train(log_interval=20):
    """Run one training epoch of the module-level ``net`` over ``train_loader``.

    Uses the module-level ``net``, ``train_loader``, ``criterion``,
    ``optimizer`` and ``device``. Every ``log_interval`` batches it prints the
    mean loss over those batches.

    Args:
        log_interval: number of batches between progress prints (>= 1).

    Returns:
        Mean training loss over all batches of the epoch, or 0.0 if the
        loader is empty.

    Raises:
        ValueError: if ``log_interval`` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError('log_interval must be >= 1')

    net.train()
    batch_num = len(train_loader)
    running_loss = 0.0
    epoch_loss = 0.0
    for i, (inputs, targets) in enumerate(train_loader, start=1):
        # Move the batch to the selected device (GPU when available).
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Accumulate the loss for the periodic report and the epoch mean.
        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        if i % log_interval == 0:
            print('batch:{}/{} loss:{:.3f}'.format(i, batch_num, running_loss / log_interval))
            running_loss = 0.0

    return epoch_loss / batch_num if batch_num else 0.0


#测试函数
def validate():
    """Evaluate the module-level ``net`` on ``val_loader`` and print top-1 accuracy.

    Returns:
        Accuracy in percent (float), or None when ``val_loader`` yields no
        samples.
    """
    # Evaluation mode: BatchNorm layers use their running statistics.
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            predicted = net(inputs).argmax(dim=1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    if total == 0:
        # Avoid ZeroDivisionError on an empty validation split.
        print('No validation images to evaluate.')
        return None

    accuracy = 100 * correct / total
    print('Accuracy of the network on the validation images: %d %%' % accuracy)
    return accuracy

# Stage 1: train only the new fc head (backbone frozen) for n_epoch epochs,
# saving a checkpoint after every epoch.
n_epoch = 10
# torch.save does not create missing directories.
os.makedirs('params', exist_ok=True)
for epoch in range(n_epoch):
    print('epoch {}'.format(epoch+1))
    train()
    validate()
    save_path = 'params/param_{}.pkl'.format(epoch)
    torch.save(net.state_dict(), save_path)
# Console output recorded by the author for this stage:
'''
epoch 1
batch:20/39 loss:1.348
Accuracy of the network on the test images: 48 %
epoch 2
batch:20/39 loss:1.183
Accuracy of the network on the test images: 55 %
epoch 3
batch:20/39 loss:1.088
Accuracy of the network on the test images: 64 %
epoch 4
batch:20/39 loss:1.005
Accuracy of the network on the test images: 68 %
epoch 5
batch:20/39 loss:0.953
Accuracy of the network on the test images: 71 %
epoch 6
batch:20/39 loss:0.896
Accuracy of the network on the test images: 73 %
epoch 7
batch:20/39 loss:0.840
Accuracy of the network on the test images: 75 %
epoch 8
batch:20/39 loss:0.797
Accuracy of the network on the test images: 77 %
epoch 9
batch:20/39 loss:0.770
Accuracy of the network on the test images: 78 %
epoch 10
batch:20/39 loss:0.729
Accuracy of the network on the test images: 78 %
'''
# To resume from the last saved checkpoint instead of continuing in-process,
# uncomment the next line:
# net.load_state_dict(torch.load(save_path))

# Continue head-only training for epochs 11-20.
os.makedirs('params', exist_ok=True)  # torch.save does not create directories
for epoch in range(10,20):
    print('epoch {}'.format(epoch+1))
    train()
    validate()
    save_path = 'params/param_{}.pkl'.format(epoch)
    torch.save(net.state_dict(), save_path)

# Console output recorded by the author for this stage:
'''
epoch 11
batch:20/39 loss:0.704
Accuracy of the network on the test images: 80 %
epoch 12
batch:20/39 loss:0.675
Accuracy of the network on the test images: 81 %
epoch 13
batch:20/39 loss:0.666
Accuracy of the network on the test images: 81 %
epoch 14
batch:20/39 loss:0.655
Accuracy of the network on the test images: 82 %
epoch 15
batch:20/39 loss:0.633
Accuracy of the network on the test images: 83 %
epoch 16
batch:20/39 loss:0.608
Accuracy of the network on the test images: 84 %
epoch 17
batch:20/39 loss:0.588
Accuracy of the network on the test images: 84 %
epoch 18
batch:20/39 loss:0.586
Accuracy of the network on the test images: 84 %
epoch 19
batch:20/39 loss:0.575
Accuracy of the network on the test images: 84 %
epoch 20
batch:20/39 loss:0.561
Accuracy of the network on the test images: 85 %
'''
# Stage 2: fine-tune the whole network. Unfreeze every parameter first, then
# build a new optimizer over all of net's parameters (stage 1 optimised only
# net.fc). train() reads the module-level `optimizer`, so rebinding it here
# takes effect immediately.
for param in net.parameters():
    param.requires_grad = True

optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)

os.makedirs('params', exist_ok=True)  # torch.save does not create directories
for epoch in range(20,30):
    print('epoch {}'.format(epoch+1))
    train()
    validate()
    save_path = 'params/param_{}.pkl'.format(epoch)
    torch.save(net.state_dict(), save_path)
# Console output recorded by the author for this stage:
'''
epoch 21
batch:20/39 loss:0.509
Accuracy of the network on the test images: 87 %
epoch 22
batch:20/39 loss:0.467
Accuracy of the network on the test images: 88 %
epoch 23
batch:20/39 loss:0.395
Accuracy of the network on the test images: 88 %
epoch 24
batch:20/39 loss:0.395
Accuracy of the network on the test images: 89 %
epoch 25
batch:20/39 loss:0.366
Accuracy of the network on the test images: 89 %
epoch 26
batch:20/39 loss:0.337
Accuracy of the network on the test images: 90 %
epoch 27
batch:20/39 loss:0.329
Accuracy of the network on the test images: 91 %
epoch 28
batch:20/39 loss:0.293
Accuracy of the network on the test images: 91 %
epoch 29
batch:20/39 loss:0.282
Accuracy of the network on the test images: 91 %
epoch 30
batch:20/39 loss:0.267
Accuracy of the network on the test images: 92 %
'''
class TestDataset(Dataset):
    """Unlabelled image dataset whose samples carry their own file path.

    Args:
        images: sequence of image file paths.
        transform: callable applied to each RGB ``PIL.Image``.
    """

    def __init__(self, images, transform):
        self.images = images
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(transformed_image, path)`` so predictions can be keyed by file."""
        path = self.images[index]
        # Close the file handle once convert() has produced an independent copy.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        return self.transform(img), path

    def __len__(self):
        return len(self.images)

# Deterministic inference pipeline: same resize/normalisation as training,
# without the random horizontal flip.
transform_test = transforms.Compose([
    transforms.Resize([256, 256]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5]),
])

# Sorted so the inference order is reproducible (glob order is unspecified).
test_images = sorted(glob.glob('test/*.jpg'))
# Use transform_test, not transform_train: the training pipeline flips images
# at random, which would make test predictions non-deterministic.
test_dataset = TestDataset(test_images, transform_test)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

def test():
    """Predict a class index for every image yielded by ``test_loader``.

    Uses the module-level ``net``, ``test_loader`` and ``device``.

    Returns:
        dict mapping each image path (as yielded by the loader) to its
        predicted class index as a Python int.
    """
    net.eval()  # evaluation mode for inference
    predictions = {}
    with torch.no_grad():
        for batch, paths in test_loader:
            scores = net(batch.to(device))
            top_class = torch.max(scores.data, 1)[1].to('cpu')
            predictions.update(
                (path, cls.item()) for path, cls in zip(paths, top_class)
            )
    return predictions
            
# Run inference and write "<image id>,<predicted label>" rows, sorted by id,
# to test.csv without a header.
result = test()
# Path(name).stem maps 'test/123.jpg' (or 'test\\123.jpg' on Windows) to '123';
# splitting on a literal backslash only worked for Windows paths.
keys = [int(Path(name).stem) for name in result]
values = list(result.values())

df = pd.DataFrame({'filename': keys, 'label': values})
df = df.sort_values(by='filename')
df = df.set_index('filename')
df.to_csv('test.csv', header=False, encoding="UTF8")
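# 发布了274 篇原创文章 · 获赞 446 · 访问量 42万+
# (Blog footer: "274 original articles published · 446 likes · 420k+ views")

# 猜你喜欢 ("You may also like")

# 转载自 (reposted from) blog.csdn.net/itnerd/article/details/104306781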