在图像分类实验中,经常能看到对数据集进行数据增强操作,其中包括transforms.Normalize(),这个函数的定义如下:
torchvision.transforms.Normalize(mean, std, inplace=False)
功能:针对RGB3个 channel 分别对图像进行标准化
output = ( input - mean ) / std
- mean: 各通道的均值
- std: 各通道的标准差
- inplace: 是否原地操作
通常ImageNet有自己的标准化参数,是通过抽样统计图像的均值方差得到的,那么针对本地特定数据集,如何获取到适合的参数呢?我参考了PyTorch数据归一化处理:transforms.Normalize及计算图像数据集的均值和方差_紫芝的博客-CSDN博客_pytorch 数据归一化
原文代码有一处错误,需要先把transform设置为transforms.ToTensor(),而不是None,否则会运行错误。以下是改正后的代码:
def getStat(train_data):
'''
Compute mean and variance for training data
:param train_data: 自定义类Dataset(或ImageFolder即可)
:return: (mean, std)
'''
print('Compute mean and variance for training data.')
print(len(train_data))
train_loader = torch.utils.data.DataLoader(
train_data, batch_size=1, shuffle=False, num_workers=0,
pin_memory=True)
mean = torch.zeros(3)
std = torch.zeros(3)
for X, _ in train_loader:
for d in range(3):
mean[d] += X[:, d, :, :].mean()
std[d] += X[:, d, :, :].std()
mean.div_(len(train_data))
std.div_(len(train_data))
return list(mean.numpy()), list(std.numpy())
if __name__ == '__main__':
train_dataset = ImageFolder(root=r'/data1/sharedata/leafseg/', transform=transforms.ToTensor())
print(getStat(train_dataset))
Compute mean and variance for training data.
3257
([0.059938803, 0.08676067, 0.041085023], [0.10522498, 0.1488454, 0.07508467])
将结果写入transform列表中即可。
data_transforms = { 'train': transforms.Compose([ transforms.Resize(640), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.0599, 0.0868, 0.0411], [0.1052, 0.1488, 0.0751]) ]), 'val': transforms.Compose([ transforms.Resize(640), transforms.ToTensor(), transforms.Normalize([0.0599, 0.0868, 0.0411], [0.1052, 0.1488, 0.0751]) ]), }