天池雪浪制造AI挑战赛(初赛)

第一次参加比赛,记录一下,我是直接使用迁移学习进行分类 采用vgg16

排名不高仅供参考

import pandas as pd
import torch
import numpy as np
from torch.autograd import Variable
import torchvision
from torchvision import transforms, models
import matplotlib.pyplot as plt
import torch.nn.functional as F 
import os
from sklearn import metrics
import sys


system = sys.platform #判断系统的,两个电脑上 路径不一样
if system == 'win32':
    os.chdir('input')
mode = 'train'  # train用来训练, test生成csv提交结果
# mode = 'test'


print('mode = ' + mode)

#这一块是pytorch自带的的载入文件夹图片
transformer = transforms.Compose([
                                  transforms.Resize((224, 224)),
                                  # transforms.CenterCrop(200),
                                  # transforms.RandomVerticalFlip(),
                                  # transforms.RandomHorizontalFlip(),
                                  transforms.ToTensor(),
                                  transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

train_data = {x: torchvision.datasets.ImageFolder(x, transform=transformer)
              for x in ['train', 'val']}

print(train_data['train'].class_to_idx)
train_loader = {}
train_loader['train'] = torch.utils.data.DataLoader(train_data['train'],
                                               batch_size=10,
                                               shuffle=True)
train_loader['val'] = torch.utils.data.DataLoader(train_data['val'],
                                               batch_size=10,
                                               shuffle=True)

print('train num is ' + str(len(train_data['train'])))
print('val num is ' + str(len(train_data['val'])))

if os.listdir('models'): #恢复模型
    print('restrore the model')
    model = torch.load('my_model.pkl')
else:
    print('use vgg16 model')

    # model = torch.load('vgg16.pkl') #因为网络不好, 我都是提前下下来保存再载入
    # model = torch.load('vgg_11_bn.pkl')
    # models.vgg16_bn(pretrained=True, batch_norm)
  
    model.classifier = torch.nn.Sequential(
        torch.nn.Linear(7*7*512, 2), #vgg提取特征不变  分类层改一下  

if torch.cuda.is_available(): #cpu gpu转换
    model = model.cuda()
print(model)


loss_func = torch.nn.CrossEntropyLoss()
lr = 1e-5

optimizer = torch.optim.Adam(model.parameters(), lr=lr)

## 建立这些列表基本都是用来画图的
epochs = 30 
plot_loss = []
best_auc = 0
auc_list = []
auc_list2 = []
train_acc_list = []
test_acc_list = []
# plt.ion()

def valling(dir_name, model):
    """
    得到网络输出   用来metrics
    0 1标签(用来算正确率)
    概率(算auc)
    label
    """
    model.eval()
    print('valling in ' + str(dir_name))
    y_pre_all = np.array(())
    test_y_all = np.array(())
    all_pro = np.array(())
    for tep_idx, [test_x, test_y] in enumerate(train_loader[dir_name]):
        if tep_idx <= 10:
            test_x, test_y = next(iter(train_loader[dir_name]))
            if torch.cuda.is_available():
                test_x, test_y = (test_x.cuda()), (test_y.cuda())

            y_out_test = model(test_x)

            all_pro = np.append(all_pro, F.softmax(y_out_test, 0).cpu().data.numpy()[:, 1])
            # print(y_out_test)
            y_pre_test = torch.argmax(y_out_test, 1)

            y_pre_test = y_pre_test.cpu().data.numpy()
            test_y = test_y.cpu().data.numpy()

            # print(y_pre_all.shape)
            # print(y_pre_test.shape)
            y_pre_all = np.append(y_pre_all, y_pre_test)
            test_y_all = np.append(test_y_all, test_y)
            # print(y_pre_all.shape)
    return y_pre_all, test_y_all, all_pro


def my_metrics(pre, label, pro):
    '''
    计算auc  acc
    '''
    # print('label shape is ' + str(label.shape))
    # print('pro shape is ' + str(pro.shape))
    auc = metrics.roc_auc_score(label, pro)
    bool_arr_test = (pre == label) 
    test_acc = np.sum(bool_arr_test) / pre.size
    return auc, test_acc


def plot_list(list1, list2, dir_, title):
    '''
    画图  train 和test的acc  auc
    '''
    abs_dir = os.path.abspath(dir_)
    if not os.path.exists(os.path.dirname(abs_dir)):
        os.mkdir(os.path.dirname(abs_dir))
        print('creat dir{}'.format(abs_dir))
    plt.figure()
    plt.plot(list1, label='train')
    plt.plot(list2, label='test')
    plt.title(title)
    plt.legend(loc='best')
    plt.savefig(dir_)
    plt.close()

if mode =='train':
    best_acc = 0
    plot_epoch_loss = []
    # print(model)
    for epoch in range(epochs):
        model.train()
        print('training')
        batch = 0
        epoch_loss = 0
        correct = 0
        # print(train_loader['train'])
        for data in train_loader['train']:
            batch += 1
            x, y = data
            if torch.cuda.is_available():
                x, y = x.cuda(), y.cuda()
            x, y = Variable(x), Variable(y)
            y_out = model(x)
            optimizer.zero_grad()
            loss = loss_func(y_out, y)
            epoch_loss += loss
            # print(loss.data)
            # print(loss.data[0])
            loss.backward()
            optimizer.step()

            a_loss = loss.cpu().data.numpy()
            plot_loss.append(a_loss)
            plt.cla()
            plt.plot(plot_loss)
            print(a_loss)
            plt.text(0, 0.5, 'loss = %.3f' % a_loss, {'color': 'red', 'size': 15})
            plt.savefig('loss2.png')
            plt.close()
            plt.pause(0.5)


        y_pre_all, test_y_all, all_pro = valling('val', model)
        train_y_pre_all, train_test_y_all, train_all_pro = valling('train', model)

        auc, test_acc = my_metrics(y_pre_all, test_y_all, all_pro)
        train_auc, train_test_acc = my_metrics(train_y_pre_all, train_test_y_all, train_all_pro)

        train_acc_list.append(train_test_acc)
        test_acc_list.append(test_acc)

        saved_figs_dir = 'vgg11_full_32' 
        plot_list(train_acc_list, test_acc_list, os.path.join('saved_figs', saved_figs_dir, 'acc.png'), 'acc_curve')

        auc_list.append(auc)
        auc_list2.append(train_auc)
        plot_list(auc_list2, auc_list, os.path.join('saved_figs', saved_figs_dir, 'auc.png'), 'auc_curve')

        best_acc = max(best_acc, test_acc) #保存最好的结果
        best_auc = max(best_auc, auc)

        print('test_acc = ' + str(test_acc * 100)[:4] + '%')
        print('train_acc = ' + str(train_test_acc * 100)[:4] + '%')
        epoch_loss = epoch_loss.cpu().data.numpy()
        print('This ' + str(epoch) + 'th epoch', 'epoch average loss = ' + str(epoch_loss/(batch)))
        plot_epoch_loss.append(epoch_loss / (batch))
        plt.figure()
        plt.plot(plot_epoch_loss)
        plt.title('epoch_loss')
        plt.savefig(os.path.join('saved_figs', saved_figs_dir, 'epoch_loss.png'))
        # plt.savefig('saved_figs/2/epoch_loss.png')
        print('lr = {}'.format(lr))
        if best_acc <= test_acc: #存正确率最高的模型
        # if best_auc <= auc:#存auc最高的
            print('score is better  store model')
            torch.save(model, 'models/my_model.pkl')
        else:
            print("not good don't save")
        print('-' * 40)    


else:
    #用来生成提交结果
    test_data = torchvision.datasets.ImageFolder('test', transform=transformer)
    test_data_loader = torch.utils.data.DataLoader(
                                        test_data, 
                                        batch_size=10,
                                        shuffle=False)
    ret_df = pd.DataFrame(columns=['filename', 'probability'])
 
    filenames = []
    for i in test_data.imgs:
        filename = os.path.basename(i[0])
        filenames.append(filename)

    # print(filenames)
    ret_df['filename'] = filenames
    for i, [x, y] in enumerate(test_data_loader):
        if torch.cuda.is_available():
            x = x.cuda()
        x = Variable(x)

        pre_out = model(x)
        pro = F.softmax(pre_out).cpu().data.numpy()[:, 1]
        pro = np.clip(pro, 0.000001, 0.999999)
        print('The ' + str(i*10) + ' th ' + 'row')

        try:
            ret_df.iloc[10*i: 10*i+10, 1] = pro
        except Exception:
            ret_df.loc[10*i:, 'probability'] = pro

    ret_df = ret_df.round(6)
    print((ret_df['probability'] <= 0).sum())
    print((ret_df['probability'] >= 1).sum())
    ret_df.to_csv('outputs/submission.csv', index=False, encoding='utf-8')





新人学习中

猜你喜欢

转载自blog.csdn.net/qq_22526061/article/details/81735134