Long-Tailed Dataset Classification with PyTorch + ResNet18 (Part 1)

The experiments are based on the paper: Class-Balanced Loss Based on Effective Number of Samples

Paper walkthrough: https://blog.csdn.net/weixin_41735859/article/details/105637597

Class-Balanced Loss code: https://github.com/vandit15/Class-balanced-loss-pytorch

ResNet18 reference code: https://blog.csdn.net/sunqiande88/article/details/80100891

Building the dataset

The paper constructs long-tailed CIFAR-10 by keeping $n_i = n\mu^i$ samples for class $i$, where $i$ is the class index, $n$ is the original per-class count (5000 for CIFAR-10), and $\mu$ controls the decay. The code below uses an imbalance ratio of 100, i.e. $\mu = (1/100)^{1/9}$, so class 9 keeps only 1/100 as many samples as class 0. Alternatively, the long-tailed datasets prepared by the paper's authors can be downloaded from the Google Drive link (accessing it may require a VPN).
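To see what this formula produces, the short snippet below (a sanity check, not part of loadcifar.py) computes the per-class targets for an imbalance ratio of 100: class 0 keeps all 5000 samples, class 9 keeps roughly 50, and the total comes to the 12406 training samples mentioned in the comments below.

import numpy as np

# Per-class counts n_i = n * mu^i with n = 5000 and mu = (1/100)**(1/9),
# so the rarest class (i = 9) keeps roughly 1/100 of its samples.
n = 5000
mu = (1 / 100) ** (1 / 9)
counts = [int(np.floor(n * mu ** i)) for i in range(10)]
print(counts)       # roughly [5000, 2997, 1796, 1077, 645, 387, 232, 139, 83, 50]
print(sum(counts))  # about 12406, matching the train_data shape noted below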

loadcifar.py

import torch
import torch.utils.data as Data
import torchvision.transforms as transforms
import numpy as np
from PIL import Image

def unpickle(file):
    import pickle
    with open(file, 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict
# Read the data from the source files.
# Returns train_data[12406, 3072] and labels[12406],
# or test_data[10000, 3072] and labels[10000].
def get_data(train=False):
    data = None
    labels = None
    new_data = None
    new_labels = []

    if train:
        for i in range(1, 6):
            batch = unpickle('data/cifar-10-batches-py/data_batch_' + str(i))
            if i == 1:
                data = batch[b'data']
                labels = batch[b'labels']
            else:
                data = np.concatenate([data, batch[b'data']])
                labels = np.concatenate([labels, batch[b'labels']])

        count = np.zeros(10, dtype=int)  # samples kept so far for each class
        for i in range(len(labels)):
            labels[i] = labels[i].reshape(1,1)
            data[i] = data[i].reshape((1,3072))
            # Keep at most n_i = 5000 * mu^i samples for class i, with mu = (1/100)**(1/9)
            if count[labels[i]] < int(np.floor(5000 * ((1 / 100) ** (1 / 9)) ** (labels[i]))):
                count[labels[i]] += 1
                if new_data is None:
                    new_data = data[i]
                else:
                    new_data = np.concatenate([new_data,data[i]])
                new_labels.append(labels[i])
            else:
                continue
        new_labels = np.array(new_labels)
        new_data = new_data.reshape(-1,3072)

    else:
        batch = unpickle('data/cifar-10-batches-py/test_batch')
        new_data = batch[b'data']
        new_labels = batch[b'labels']

    return new_data, new_labels

# Image preprocessing; Compose chains several transform operations together.
# For color images the channel statistics are not stationary, so each channel
# is normalized separately.
transform = transforms.Compose([
    # ToTensor converts a PIL.Image (RGB) or numpy.ndarray (H x W x C)
    # with values in [0, 255] to a Tensor with values in [0, 1].
    transforms.ToTensor(),
    # Normalize maps the image data to [-1, 1].
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# Convert a label to a torch.LongTensor
def target_transform(label):
    label = np.array(label)
    target = torch.from_numpy(label).long()
    return target

'''
Custom dataset class for loading the (long-tailed) CIFAR-10 data;
it must inherit from data.Dataset.
'''
class Cifar10_Dataset(Data.Dataset):
    def __init__(self, train=True, transform=None, target_transform=None):
        # Store the transforms and the train/test flag
        self.transform = transform
        self.target_transform = target_transform
        self.train = train

        # Load the training set
        if self.train:
            self.train_data, self.train_labels = get_data(train)
            num = self.train_data.shape[0]
            self.train_data = self.train_data.reshape((num, 3, 32, 32))
            # Convert the image data to [height, width, channels] for preprocessing
            self.train_data = self.train_data.transpose((0, 2, 3, 1))
        else:
            # Load the test set
            self.test_data, self.test_labels = get_data()
            self.test_data = self.test_data.reshape((10000, 3, 32, 32))
            self.test_data = self.test_data.transpose((0, 2, 3, 1))
        pass

    def __getitem__(self, index):
        # Read one sample from the dataset, preprocess it,
        # and return a data pair such as (data, label)
        if self.train:
            img, label = self.train_data[index], self.train_labels[index]
        else:
            img, label = self.test_data[index], self.test_labels[index]
        img = Image.fromarray(img)
        # Image preprocessing
        if self.transform is not None:
            img = self.transform(img)
        # Label preprocessing (fall back to the raw label if no transform is given)
        target = label
        if self.target_transform is not None:
            target = self.target_transform(label)
        return img, target

    def __len__(self):
        # Return the size of the dataset
        if self.train:
            return len(self.train_data)
        else:
            return len(self.test_data)

if __name__ == '__main__':
    # Load the training and test sets
    train_data = Cifar10_Dataset(True, transform, target_transform)
    print('size of train_data:{}'.format(len(train_data)))
    test_data = Cifar10_Dataset(False, transform, target_transform)
    print('size of test_data:{}'.format(len(test_data)))
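To feed this dataset into the training step later, it can be wrapped in a standard PyTorch DataLoader. A minimal sketch follows; the batch size and worker count are arbitrary choices, not taken from the original post.

from torch.utils.data import DataLoader

train_set = Cifar10_Dataset(True, transform, target_transform)
test_set = Cifar10_Dataset(False, transform, target_transform)

# Shuffle the long-tailed training set each epoch; batch size 128 is arbitrary.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)

for images, targets in train_loader:
    print(images.shape, targets.shape)  # e.g. torch.Size([128, 3, 32, 32]) torch.Size([128])
    break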

Step 2: define the loss function (a preview sketch follows below)
Step 3: training
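As a preview of the loss-function step, the paper weights each class by the inverse of its effective number of samples, $E_{n_i} = (1-\beta^{n_i})/(1-\beta)$. Below is a minimal sketch of a class-balanced cross-entropy built on that formula; the function name class_balanced_ce_loss and the default beta = 0.9999 are illustrative assumptions, not code from the original post (the linked Class-balanced-loss-pytorch repo provides a fuller CB_loss implementation).

import numpy as np
import torch
import torch.nn.functional as F

def class_balanced_ce_loss(logits, targets, samples_per_cls, beta=0.9999):
    # Effective number of samples per class: E_n = (1 - beta^n) / (1 - beta)
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    weights = (1.0 - beta) / effective_num
    # Normalize the weights so they sum to the number of classes
    weights = weights / weights.sum() * len(samples_per_cls)
    weights = torch.tensor(weights, dtype=torch.float32, device=logits.device)
    # Weighted cross-entropy: rare classes receive larger weights
    return F.cross_entropy(logits, targets, weight=weights)

Here samples_per_cls would be the per-class counts produced by get_data above (roughly [5000, 2997, ..., 50] for the imbalance ratio of 100).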

Reposted from blog.csdn.net/weixin_41735859/article/details/105910383