Tabla de contenido
Introducción al conjunto de datos
Enlace de descarga del conjunto de datos
Función de conjunto de datos - leer datos
Método de codificación de etiquetas
Código completo (función CamvidDataset)
Introducción al conjunto de datos
Se utiliza el conjunto de datos de escena de conducción Camvid, que contiene 701 imágenes de segmentación semántica de escena de conducción, que se dividen en conjunto de entrenamiento, conjunto de verificación y conjunto de prueba, con 367, 101 y 233 imágenes respectivamente.
El directorio del conjunto de datos es el siguiente:
Enlace de descarga del conjunto de datos
Enlace: https://pan.baidu.com/s/1HLviQ3AUU7jinWX0YCMWtA?pwd=aaaa
Código de extracción: aaaa
Función de conjunto de datos - leer datos
Pasos de lectura de datos
1. Qué datos leer: salida de índice por muestreador
2. Dónde leer datos: root_dir (ruta) en Dataset
3. Cómo leer datos: la función __getitem__(self,index) en Dataset lee datos de acuerdo con el índice de índice (usted mismo debe escribir la función clave)
introducción de la etiqueta
Interceptar algunas etiquetas en train_labels
Se puede ver que: a diferencia de la etiqueta en la clasificación de imágenes, es una etiqueta específica 0 1 2 ... 11 (la imagen completa representa una categoría ); la etiqueta en la segmentación de imágenes es una imagen de tres canales en color RGB, y diferentes colores representan diferentes categorías ( toda la imagen se divide en diferentes categorías píxel por píxel ), y la tabla de correspondencia entre colores y categorías se muestra en class_dict.csv. (Un total de 12 categorías)
Método de codificación de etiquetas
Lea el archivo class_dict.csv y genere un mapa de colores:
mapa de colores=[[128, 128, 128],[128, 0, 0],[192, 192, 128],[128, 64, 128],[0, 0, 192],[128, 128, 0] [192, 128, 128],[64, 64, 128],[64, 0, 128],[64, 64, 0],[0, 128, 192],[0, 0, 0]]
(Un total de 12 categorías, utilice el subíndice colormap.index(a) de los elementos de la lista para indicar la categoría del elemento a)
Lea cualquier etiqueta, cambie su forma de (h,w,3)->(h,w), cada elemento en (h,w) representa la categoría del píxel actual
import numpy as np
from PIL import Image
colormap=[[128, 128, 128],[128, 0, 0], [192, 192, 128],[128, 64, 128],[0, 0, 192],[128, 128, 0],[192, 128, 128],[64, 64, 128],[64, 0, 128],[64, 64, 0],[0, 128, 192],[0, 0, 0]]
label_path=r'D:\图像分割\camvid_from_paper\train_labels\0001TP_006690_L.png'
label=Image.open(label_path)
label = np.array(label) # 此时label.shape=(h,w,3)
h, w, _ = label.shape
label = label.tolist() # 将label转化为list,三维列表
# 遍历label中的每一个元素,为RGB三通道颜色,例如[128,0,0]
for i in range(h):
for j in range(w):
label[i][j] = colormap.index(label[i][j]) # colormap中元素的下标0-11作为类别0-11
label = np.array(label,dtype='int64').reshape((h, w)) # reshape为(h,w)
print(label)
Este código se define en el código completo de la función LabelProcessor.cm2label
Código completo (función CamvidDataset)
from PIL import Image
from torch.utils.data import Dataset,DataLoader
import pandas as pd
import numpy as np
import torchvision.transforms as transforms
import os
import torch
class LabelProcessor:
cls_num = 12
def __init__(self,file_path):
"""
self.colormap 颜色表 [[128,128,128],[128,0,0],[],...,[]] ['r','g','b']
self.names 类别名
"""
self.colormap,self.names=self.read_color_map(file_path)
def read_color_map(self,file_path):
# 读取csv文件
pd_read_color=pd.read_csv(file_path)
colormap=[]
names=[]
for i in range(len(pd_read_color)):
temp=pd_read_color.iloc[i] # DataFrame格式的按行切片
color=[temp['r'],temp['g'],temp['b']]
colormap.append(color)
names.append(temp['name'])
return colormap,names
def cm2label(self,label):
"""将RGB三通道label (h,w,3)转化为 (h,w)大小,每一个值为当前像素点的类别"""
label = np.array(label)
h, w, _ = label.shape
label = label.tolist()
for i in range(h):
for j in range(w):
label[i][j] = self.colormap.index(label[i][j])
label = np.array(label,dtype='int64').reshape((h, w))
return label
class CamvidDataset(Dataset):
def __init__(self,img_dir,label_dir,file_path):
"""
:param img_dir: 图片路径
:param label_dir: 图片对应的label路径
:param file_path: csv文件(colormap)路径
"""
self.img_dir=img_dir
self.label_dir=label_dir
self.imgs=self.read_file(self.img_dir)
self.labels=self.read_file(self.label_dir)
self.label_processor=LabelProcessor(file_path)
# 类别总数与以及类别名
self.cls_num=self.label_processor.cls_num
self.names=self.label_processor.names
def __getitem__(self, index):
"""根据index下标索引对应的img以及label"""
img=self.imgs[index]
label=self.labels[index]
img=Image.open(img)
label=Image.open(label)
img,label=self.img_transform(img,label)
return img,label
def __len__(self):
if len(self.imgs)==0:
raise Exception('Please check your img_dir'.format(self.img_dir))
return len(self.imgs)
def read_file(self,path):
"""生成每个图片路径名的列表,用于getitem中索引"""
file_path=os.listdir(path)
file_path_list=[os.path.join(path,img_name) for img_name in file_path]
file_path_list.sort()
return file_path_list
def img_transform(self,img,label):
"""对图片做transform"""
transform_img=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
img=transform_img(img)
label = self.label_processor.cm2label(label)
label=torch.from_numpy(label) # numpy转化为tensor
return img,label
if __name__=='__main__':
# 路径
root_dir='D:\图像分割\camvid_from_paper'
img_path = os.path.join(root_dir,'train')
label_path = os.path.join(root_dir,'train_labels')
file_path = os.path.join(root_dir,'class_dict.csv')
train_data=CamvidDataset(img_path,label_path,file_path)
train_loader=DataLoader(train_data,batch_size=8,shuffle=True,num_workers=0)
for i,data in enumerate(train_loader):
img_data,label_data=data
print(img_data.shape,type(img_data))
print(label_data.shape,type(label_data))
Resultado de salida:
antorcha.Tamaño([8, 3, 360, 480]) <clase 'antorcha.Tensor'>
antorcha.Tamaño([8, 360, 480]) <clase 'antorcha.Tensor'>
(cada elemento en label_data es un número entre 0-11)