画像セグメンテーションにおける RGB 3 チャネル ラベルのエンコーディング (Camvid データセットの Dataset 関数に基づく完全なコード)

目次

データセットの紹介

データセットのダウンロード リンク

データセット関数 - データの読み取り

データ読み取り手順

レーベル紹介

ラベルのエンコード方式

完全なコード (CamvidDataset 関数)


データセットの紹介

Camvid 運転シーン データセットが使用されます。これには 701 個の運転シーン セマンティック セグメンテーション画像が含まれており、トレーニング セット、検証セット、テスト セットに分割され、それぞれ 367、101、233 個の画像が含まれます。

データセットのディレクトリは次のとおりです。

データセットのダウンロード リンク

リンク: https://pan.baidu.com/s/1HLviQ3AUU7jinWX0YCMWtA?pwd=aaaa 
抽出コード: aaaa

データセット関数 - データの読み取り

データ読み取り手順

1. 読み込むデータ:サンプラーが出力したインデックス

2. データを読み取る場所: データセットの root_dir (パス)

3. データの読み取り方法: Dataset の __getitem__(self,index) 関数はインデックスインデックスに従ってデータを読み取ります (key 関数を自分で記述する必要があります)

レーベル紹介

train_labels のいくつかのラベルをインターセプトします

画像分類におけるラベルとは異なり、特定のラベル 0 1 2 ... 11 (画像全体がカテゴリを表す) であること、画像分割におけるラベルはカラー RGB 3 チャネル画像であること、および異なる色は異なるカテゴリを表し (画像全体がピクセルごとに異なるカテゴリに分割されます)、色とカテゴリの対応表は class_dict.csv に示されています。(全12カテゴリー)

ラベルのエンコード方式

 class_dict.csv ファイルを読み取り、カラーマップを生成します。

カラーマップ=[[128, 128, 128],[128, 0, 0],[192, 192, 128],[128, 64, 128],[0, 0, 192],[128, 128, 0] [192, 128, 128],[64, 64, 128],[64, 0, 128],[64, 64, 0],[0, 128, 192],[0, 0, 0]]

(合計 12 のカテゴリ。リスト内の要素の添字 colormap.index(a) を使用して、要素 a のカテゴリを示します)

任意のラベルを読み取り、その形状を (h,w,3)->(h,w) に変更します。(h,w) の各要素は現在のピクセルのカテゴリを表します

import numpy as np
from PIL import Image

colormap=[[128, 128, 128],[128, 0, 0], [192, 192, 128],[128, 64, 128],[0, 0, 192],[128, 128, 0],[192, 128, 128],[64, 64, 128],[64, 0, 128],[64, 64, 0],[0, 128, 192],[0, 0, 0]]

label_path=r'D:\图像分割\camvid_from_paper\train_labels\0001TP_006690_L.png'
label=Image.open(label_path)
label = np.array(label)  # 此时label.shape=(h,w,3)
h, w, _ = label.shape
label = label.tolist()  # 将label转化为list,三维列表 

# 遍历label中的每一个元素,为RGB三通道颜色,例如[128,0,0]
for i in range(h):
    for j in range(w):
        label[i][j] = colormap.index(label[i][j])  # colormap中元素的下标0-11作为类别0-11
label = np.array(label,dtype='int64').reshape((h, w))  # reshape为(h,w)
print(label)

このコードは、完全なコード LabelProcessor.cm2label 関数で定義されています。

完全なコード (CamvidDataset 関数)

from PIL import Image
from torch.utils.data import Dataset,DataLoader
import pandas as pd
import numpy as np
import torchvision.transforms as transforms
import os
import torch


class LabelProcessor:
    cls_num = 12
    def __init__(self,file_path):
        """
        self.colormap 颜色表 [[128,128,128],[128,0,0],[],...,[]]   ['r','g','b']
        self.names 类别名
        """
        self.colormap,self.names=self.read_color_map(file_path)  

    def read_color_map(self,file_path):
        # 读取csv文件
        pd_read_color=pd.read_csv(file_path)
        colormap=[]
        names=[]

        for i in range(len(pd_read_color)):
            temp=pd_read_color.iloc[i]  # DataFrame格式的按行切片
            color=[temp['r'],temp['g'],temp['b']]
            colormap.append(color)
            names.append(temp['name'])
        return colormap,names
    
    def cm2label(self,label):
        """将RGB三通道label (h,w,3)转化为 (h,w)大小,每一个值为当前像素点的类别"""
        label = np.array(label)
        h, w, _ = label.shape
        label = label.tolist()

        for i in range(h):
            for j in range(w):           
                label[i][j] = self.colormap.index(label[i][j])  
        label = np.array(label,dtype='int64').reshape((h, w))
        return label

class CamvidDataset(Dataset):
    def __init__(self,img_dir,label_dir,file_path):
        """
        :param img_dir: 图片路径
        :param label_dir: 图片对应的label路径
        :param file_path: csv文件(colormap)路径
        """
        self.img_dir=img_dir
        self.label_dir=label_dir

        self.imgs=self.read_file(self.img_dir)
        self.labels=self.read_file(self.label_dir)
        
        self.label_processor=LabelProcessor(file_path)
        # 类别总数与以及类别名
        self.cls_num=self.label_processor.cls_num
        self.names=self.label_processor.names

    def __getitem__(self, index):
        """根据index下标索引对应的img以及label"""
        img=self.imgs[index]
        label=self.labels[index]

        img=Image.open(img)
        label=Image.open(label)

        img,label=self.img_transform(img,label)

        return img,label

    def __len__(self):
        if len(self.imgs)==0:
            raise Exception('Please check your img_dir'.format(self.img_dir))
        return len(self.imgs)

    def read_file(self,path):
        """生成每个图片路径名的列表,用于getitem中索引"""
        file_path=os.listdir(path)
        file_path_list=[os.path.join(path,img_name) for img_name in file_path]
        file_path_list.sort()

        return file_path_list

    def img_transform(self,img,label):
        """对图片做transform"""
        transform_img=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        img=transform_img(img)

        label = self.label_processor.cm2label(label)
        label=torch.from_numpy(label)   # numpy转化为tensor

        return img,label


if __name__=='__main__':
    # 路径
    root_dir='D:\图像分割\camvid_from_paper'
    img_path = os.path.join(root_dir,'train')
    label_path = os.path.join(root_dir,'train_labels')
    file_path = os.path.join(root_dir,'class_dict.csv')

    train_data=CamvidDataset(img_path,label_path,file_path)
    train_loader=DataLoader(train_data,batch_size=8,shuffle=True,num_workers=0)

    for i,data in enumerate(train_loader):
        img_data,label_data=data
        print(img_data.shape,type(img_data))
        print(label_data.shape,type(label_data))

 出力結果:

torch.Size([8, 3, 360, 480]) <クラス 'torch.Tensor'>

torch.Size([8, 360, 480]) <クラス 'torch.Tensor'>

(label_data の各要素は 0 ~ 11 の数字です)

 

おすすめ

転載: blog.csdn.net/m0_63077499/article/details/127365373