Codificación de etiquetas RGB de tres canales en segmentación de imágenes (código completo basado en la función Dataset del dataset Camvid)

Tabla de contenido

Introducción al conjunto de datos

Enlace de descarga del conjunto de datos

Función de conjunto de datos - leer datos

Pasos de lectura de datos

introducción de la etiqueta

Método de codificación de etiquetas

Código completo (función CamvidDataset)


Introducción al conjunto de datos

Se utiliza el conjunto de datos de escena de conducción Camvid, que contiene 701 imágenes de segmentación semántica de escena de conducción, que se dividen en conjunto de entrenamiento, conjunto de verificación y conjunto de prueba, con 367, 101 y 233 imágenes respectivamente.

El directorio del conjunto de datos es el siguiente:

Enlace de descarga del conjunto de datos

Enlace: https://pan.baidu.com/s/1HLviQ3AUU7jinWX0YCMWtA?pwd=aaaa 
Código de extracción: aaaa

Función de conjunto de datos - leer datos

Pasos de lectura de datos

1. Qué datos leer: salida de índice por muestreador

2. Dónde leer datos: root_dir (ruta) en Dataset

3. Cómo leer datos: la función __getitem__(self,index) en Dataset lee datos de acuerdo con el índice de índice (usted mismo debe escribir la función clave)

introducción de la etiqueta

Interceptar algunas etiquetas en train_labels

Se puede ver que: a diferencia de la etiqueta en la clasificación de imágenes, es una etiqueta específica 0 1 2 ... 11 (la imagen completa representa una categoría ); la etiqueta en la segmentación de imágenes es una imagen de tres canales en color RGB, y diferentes colores representan diferentes categorías ( toda la imagen se divide en diferentes categorías píxel por píxel ), y la tabla de correspondencia entre colores y categorías se muestra en class_dict.csv. (Un total de 12 categorías)

Método de codificación de etiquetas

 Lea el archivo class_dict.csv y genere un mapa de colores:

mapa de colores=[[128, 128, 128],[128, 0, 0],[192, 192, 128],[128, 64, 128],[0, 0, 192],[128, 128, 0] [192, 128, 128],[64, 64, 128],[64, 0, 128],[64, 64, 0],[0, 128, 192],[0, 0, 0]]

(Un total de 12 categorías, utilice el subíndice colormap.index(a) de los elementos de la lista para indicar la categoría del elemento a)

Lea cualquier etiqueta, cambie su forma de (h,w,3)->(h,w), cada elemento en (h,w) representa la categoría del píxel actual

import numpy as np
from PIL import Image

colormap=[[128, 128, 128],[128, 0, 0], [192, 192, 128],[128, 64, 128],[0, 0, 192],[128, 128, 0],[192, 128, 128],[64, 64, 128],[64, 0, 128],[64, 64, 0],[0, 128, 192],[0, 0, 0]]

label_path=r'D:\图像分割\camvid_from_paper\train_labels\0001TP_006690_L.png'
label=Image.open(label_path)
label = np.array(label)  # 此时label.shape=(h,w,3)
h, w, _ = label.shape
label = label.tolist()  # 将label转化为list,三维列表 

# 遍历label中的每一个元素,为RGB三通道颜色,例如[128,0,0]
for i in range(h):
    for j in range(w):
        label[i][j] = colormap.index(label[i][j])  # colormap中元素的下标0-11作为类别0-11
label = np.array(label,dtype='int64').reshape((h, w))  # reshape为(h,w)
print(label)

Este código se define en el código completo de la función LabelProcessor.cm2label

Código completo (función CamvidDataset)

from PIL import Image
from torch.utils.data import Dataset,DataLoader
import pandas as pd
import numpy as np
import torchvision.transforms as transforms
import os
import torch


class LabelProcessor:
    cls_num = 12
    def __init__(self,file_path):
        """
        self.colormap 颜色表 [[128,128,128],[128,0,0],[],...,[]]   ['r','g','b']
        self.names 类别名
        """
        self.colormap,self.names=self.read_color_map(file_path)  

    def read_color_map(self,file_path):
        # 读取csv文件
        pd_read_color=pd.read_csv(file_path)
        colormap=[]
        names=[]

        for i in range(len(pd_read_color)):
            temp=pd_read_color.iloc[i]  # DataFrame格式的按行切片
            color=[temp['r'],temp['g'],temp['b']]
            colormap.append(color)
            names.append(temp['name'])
        return colormap,names
    
    def cm2label(self,label):
        """将RGB三通道label (h,w,3)转化为 (h,w)大小,每一个值为当前像素点的类别"""
        label = np.array(label)
        h, w, _ = label.shape
        label = label.tolist()

        for i in range(h):
            for j in range(w):           
                label[i][j] = self.colormap.index(label[i][j])  
        label = np.array(label,dtype='int64').reshape((h, w))
        return label

class CamvidDataset(Dataset):
    def __init__(self,img_dir,label_dir,file_path):
        """
        :param img_dir: 图片路径
        :param label_dir: 图片对应的label路径
        :param file_path: csv文件(colormap)路径
        """
        self.img_dir=img_dir
        self.label_dir=label_dir

        self.imgs=self.read_file(self.img_dir)
        self.labels=self.read_file(self.label_dir)
        
        self.label_processor=LabelProcessor(file_path)
        # 类别总数与以及类别名
        self.cls_num=self.label_processor.cls_num
        self.names=self.label_processor.names

    def __getitem__(self, index):
        """根据index下标索引对应的img以及label"""
        img=self.imgs[index]
        label=self.labels[index]

        img=Image.open(img)
        label=Image.open(label)

        img,label=self.img_transform(img,label)

        return img,label

    def __len__(self):
        if len(self.imgs)==0:
            raise Exception('Please check your img_dir'.format(self.img_dir))
        return len(self.imgs)

    def read_file(self,path):
        """生成每个图片路径名的列表,用于getitem中索引"""
        file_path=os.listdir(path)
        file_path_list=[os.path.join(path,img_name) for img_name in file_path]
        file_path_list.sort()

        return file_path_list

    def img_transform(self,img,label):
        """对图片做transform"""
        transform_img=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        img=transform_img(img)

        label = self.label_processor.cm2label(label)
        label=torch.from_numpy(label)   # numpy转化为tensor

        return img,label


if __name__=='__main__':
    # 路径
    root_dir='D:\图像分割\camvid_from_paper'
    img_path = os.path.join(root_dir,'train')
    label_path = os.path.join(root_dir,'train_labels')
    file_path = os.path.join(root_dir,'class_dict.csv')

    train_data=CamvidDataset(img_path,label_path,file_path)
    train_loader=DataLoader(train_data,batch_size=8,shuffle=True,num_workers=0)

    for i,data in enumerate(train_loader):
        img_data,label_data=data
        print(img_data.shape,type(img_data))
        print(label_data.shape,type(label_data))

 Resultado de salida:

antorcha.Tamaño([8, 3, 360, 480]) <clase 'antorcha.Tensor'>

antorcha.Tamaño([8, 360, 480]) <clase 'antorcha.Tensor'>

(cada elemento en label_data es un número entre 0-11)

 

Supongo que te gusta

Origin blog.csdn.net/m0_63077499/article/details/127365373
Recomendado
Clasificación