【实验】语义分割-数据集

参考

视频李沐-语义分割和数据集【动手学深度学习v2】
笔记李沐视频-笔记
视频笔记

本文主要讲语义分割的经典数据集——VOC2012,的读取。

一句话概括语义分割:在图片中进行像素级的分类

数据集

PASCAL VOC2012数据集介绍
最重要的语义分割数据集之一是 Pascal VOC 2012
这个数据集有自己的格式 – VOC格式,它是一个使用非常广泛的格式(VOC、COCO 都是比较有名的数据集)

VOC 2012 数据集组件

  • ImageSets/Segmentation:该路径下包含用于训练和测试样本的文本文件
  • JPEGImages:该路径下存储着每个实例的输入图像
  • SegmentationClass:该路径下存储着每个实例的标签(此处的标签也采用图像格式,其尺寸和它所标注的输入图像的尺寸相同;标签中颜色相同的像素属于同一个语义类别)

预处理数据

  • 在之前的任务中,使用再缩放图像使其符合模型的输入形状,而在语义分割中,这样做需要将预测的像素类别重新映射回原始尺寸的输入图像,这样的映射可能不够精确,尤其是在不同语义的分割区域
  • 为避免这个问题,将图像裁剪为固定尺寸,而不再是缩放:使用图像增广中的随机裁剪,裁剪输入图像和标签的相同区域

总结

  • 语义分割通过将图像划分为属于不同语义类别的区域,来识别并理解图像中像素级别的内容
  • 由于语义分割的输入图像和标签在像素上一一对应,输入图像会被随机裁剪为固定尺寸而不是缩放

代码

#!/usr/bin/env python
# coding: utf-8

# In[2]:



import os
import torch
import torchvision
from d2l import torch as d2l

#@save
#下载数据集,并解压
d2l.DATA_HUB['voc2012'] = (d2l.DATA_URL + 'VOCtrainval_11-May-2012.tar',
                           '4e443f8a2eca6b1dac8a6c57641b67dd40621a49')

voc_dir = d2l.download_extract('voc2012', 'VOCdevkit/VOC2012')


# In[6]:

# 将所有输入的图片和标签读入内存
def read_voc_images(voc_dir, is_train=True):
#读取所有voc图像并标注
    txt_fname = os.path.join(voc_dir, 'ImageSets', 'Segmentation',
                             'train.txt' if is_train else 'val.txt')
    mode = torchvision.io.image.ImageReadMode.RGB
    with open(txt_fname, 'r') as f:
        images = f.read().split()
    features, labels = [], []
    for i, fname in enumerate(images):
        features.append(torchvision.io.read_image(os.path.join(voc_dir, 'JPEGImages', f'{
      
      fname}.jpg')))
        labels.append(torchvision.io.read_image(os.path.join(voc_dir, 'SegmentationClass', f'{
      
      fname}.png'), mode))
    return features, labels
# torchvision.io.read_image 读取之后通道位于第一个维度, 在display 的时候需要将通道数移动到最后一个维度
train_features, train_labels = read_voc_images(voc_dir, True)


# 绘制前5张输入图像的标签
n = 5
imgs = train_features[:5] + train_labels[:5]
imgs = [img.permute(1, 2, 0) for img in imgs]
d2l.show_images(imgs, 2, n)


# In[9]:
# VOC 不同颜色对应不同的类
VOC_COLORMAP = [[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
                [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
                [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
                [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
                [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
                [0, 64, 128]]

#@save
VOC_CLASSES = [
    'background', 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
    'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
    'person', 'potted plant', 'sheep', 'sofa', 'train', 'tv/monitor']


# In[49]:
# # 构建从RGB 到VOC类别索引的映射
def voc_colormap2label():
    colormap2label = torch.zeros(256**3, dtype=torch.long)
    for i, colormap in enumerate(VOC_COLORMAP):
        colormap2label[(colormap[0] * 256 + colormap[1]) * 256 +
                       colormap[2]] = i
    return colormap2label

# 将标签中点的RGB值映射到类别索引
def voc_label_indices(colormap, colormap2label): # 将colormap 是channel * height * width
    colormap = colormap.permute(1, 2, 0).numpy().astype('int32')
    idx = (colormap[:, :, 0] * 256 + colormap[:, :, 1]) * 256 + colormap[:, :, 2]
#     print(idx.shape)
    return colormap2label[idx]
    


# In[50]:


y = voc_label_indices(train_labels[0], voc_colormap2label())
y[105:115, 130:140], VOC_CLASSES[1]


# In[47]:

# 使用图片增广中的随机裁剪,裁剪输入图像和标签的相同区域
def voc_rand_crop(feature, label, height, width):
    """random crop 特征和标签"""
    rect = torchvision.transforms.RandomCrop.get_params(feature, (height, width))
#     print(rect)
    feature = torchvision.transforms.functional.crop(feature, *rect)
    label = torchvision.transforms.functional.crop(label, *rect)
    return feature, label


# In[30]:
#展示随机裁剪
imgs = []
for _ in range(n):
    imgs += voc_rand_crop(train_features[0], train_labels[0], 200, 300)
    
imgs = [img.permute(1, 2, 0) for img in imgs]
d2l.show_images(imgs[::2] + imgs[1::2], 2, n)

#自定义语义分割数据集类
# In[32]:
# 用户自定义dataset class
# 至少需要实现,init, getitem, len
# 图片分割不好用resize,因为对label进行resize 会有歧义。但可以使用crop
class VOCSegDataset(torch.utils.data.Dataset):
    def __init__(self, is_train, crop_size, voc_dir):
        self.transform = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.crop_size = crop_size
        
        features, labels = read_voc_images(voc_dir, is_train=is_train)
        self.features = [self.normalize_image(feature) for feature in self.filter(features)] # 经过filter 和 normalize
        self.labels = self.filter(labels)
        
        self.colormap2label = voc_colormap2label()
        print(f'read {
      
      len(self.features)} examples')
        
        
    def normalize_image(self, img):
        return self.transform(img.float()) #???????/
    
    def filter(self, imgs): # 由于一些图片的大小比crop_size 的图片还要小
        return [img for img in imgs if (img.shape[1] >= self.crop_size[0] and  img.shape[2] >= self.crop_size[1])]
        
    def __getitem__(self, idx):
        feature, label = voc_rand_crop(self.features[idx], self.labels[idx], *self.crop_size)
        return (feature, voc_label_indices(label, self.colormap2label))
        
    def __len__(self):
        return len(self.features)


# In[33]:

# 测试Dataset
crop_size = (320, 480)
voc_train = VOCSegDataset(True, crop_size, voc_dir)
voc_test = VOCSegDataset(False, crop_size, voc_dir)


# In[51]:


batch_size = 64
train_iter = torch.utils.data.DataLoader(
    voc_train, batch_size, shuffle=True, drop_last=True)
#     num_workers=d2l.get_dataloader_workers())
for X, Y in train_iter:
    print(X.shape)
    print(Y.shape)
    break


# In[45]:


# 组合为一个函数
def load_data_voc(batch_size, crop_size):
    """Load the VOC semantic segmentation dataset."""
    voc_dir = d2l.download_extract('voc2012',
                                   os.path.join('VOCdevkit', 'VOC2012'))
    num_workers = d2l.get_dataloader_workers()
    
    train_iter = torch.utils.data.DataLoader(
        VOCSegDataset(True, crop_size, voc_dir), batch_size, shuffle=True,
        drop_last=True, num_workers=num_workers)
    
    test_iter = torch.utils.data.DataLoader(
        VOCSegDataset(False, crop_size, voc_dir), batch_size, drop_last=True,
        num_workers=num_workers)
    return train_iter, test_iter

猜你喜欢

转载自blog.csdn.net/zhe470719/article/details/124901899