セマンティックセグメンテーションシリーズ 3-SegNet (pytorch実装)

SegNet の原稿は 2015 年 12 月に初めて提出され、FCN と同じ時期に属します。FCN より少し遅れて後発に属するため、FCN と同様にセマンティック セグメンテーション ネットワークにも属しており、SegNet の論文では FCN ネットワークとの比較が数多く行われています。

SegNet: 画像セグメンテーションのためのディープ畳み込みエンコーダ/デコーダ アーキテクチャ》 


目次

セグネット

デザインの動機

ネットワーク構造

プールインデックス

結果

モデルの再現

データセットの構築

データセットクラス

データセットとデータローダーを作成する

モデル構築

モデルトレーニング

要約する


セグネット

デザインの動機

著者は、FCN ネットワークのセグメンテーションの結果は有望であると考えていますが、プーリングとダウンサンプリングのプロセスにより特徴マップの解像度が低下し、一部の情報が失われ、結果が粗くなると考えています。したがって、著者らは、ピクセル レベルの分類を改善するために、低解像度の特徴を入力解像度にマッピングするように SegNet を設計しました。

第二に、当時、FCN ネットワークは比較的大規模なモデルと考えられており、符号化層のパラメータは 1 億 3400 万個ですが、復号化層のパラメータは 0.5M しかありませんでした。模型は電車にするには大きすぎました。

したがって、著者は、エンコーダ ネットワーク内の各エンコーダがデコーダ ネットワーク内の SegNet に段階的に接続されるエンドツーエンドを設計しました。アイデアは非常にシンプルで、複数のスケールで抽出された特徴とグローバル コンテキスト情報を保存し、アップサンプリングに利用できる情報をさらに提供することで、より多くの高周波の詳細を保持し、細かいセグメンテーションを実現するというものです。

ネットワーク構造

図1 SegNetのネットワーク構造

前述したように、SegNet はエンコーダ/デコーダ ネットワーク構造を使用しており、各エンコーダ層はデコーダ層に対応し、最後の層はピクセル分類用のソフトマックス分類器です。

このうち、エンコーダ ネットワークは VGG16 の最初の 13 層で構成されており、完全に接続された VGG16 の最後の 3 層がたまたま削除されています。SegNet はトレーニングされた VGG16 のネットワーク パラメーターで初期化できるため、これはより便利です。同時に著者は、復号層のパラメータ量がわずか14.7Mであり、134M FCNと比較してパラメータ量がわずか10分の1であることにも言及しました。

エンコード層のアーキテクチャは、VGG16 の最初の 13 層で比較的単純で、畳み込み、バッチ正規化、ReLu アクティベーションの一連の演算を重ね合わせて特徴を抽出し、コア 2 とステップ サイズの MaxPool を使用します。 2 を使用してダウンサンプリングし、画像の入力平行移動不変性を実現します。ただし、このプーリングとダウンサンプリング操作により特徴マップの解像度が失われるため、レイヤー数が深くなると特徴マップの解像度が低下し、後で元の画像の精細さを復元することが困難になります。アップサンプリング。したがって、作者はエンコーダモジュールでいくつかの作業を行いました。

プールインデックス

ダウンサンプリング プロセス中にいくつかの重要な情報を保存するために、著者らは、プーリング層インデックス (図 1 のプーリング インデックス) を保存して、エンコーダの特徴マップに境界情報を取得して保存する方法を提案しています。これはFCNやUnetのスキップ接続とは異なり、一つは同じ次元の符号化層と復号層の特徴マップを重ね合わせることであり、もう一つは対応する次元のプーリング層インデックスを保存することで画像再構成を助けることである。

SegNet はアップサンプリングの動作が FCN とは異なります。SegNet は、予約されたプーリング インデックスに従って機能をマップします。このステップでは学習は必要なく、その後、トレーニング可能なデコード フィルター (実際にはいくつかの畳み込み層) が続きます。FCN は、デコンボリューション (デコンボリューション) 操作を通じて実装されます。

SegNet のアップサンプリング プロセス中に、特徴はプーリング インデックスによってマッピングされ、畳み込みのためにトレーニング可能なマルチチャネル デコード フィルターに入力され、そのスパースな特徴が強化されます。

図 2 SegNet アップサンプリングと FCN アップサンプリングのプロセス

結果

図 3 CamVid データセットに対する SegNet の効果


モデルの再現

この記事では、CamVid データセット上で SegNet モデルを再現します。 

データセットの構築

まず、いくつかの厄介なライブラリをインポートします。

# 导入库
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import optim
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")
import os.path as osp
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

データセットクラス

Camvid には 32 のクラスがあります。ここでのデータ拡張では、pip を通じてインストールできる albumentations ライブラリを使用します。その理由は、pytorch ライブラリがラベルと画像の同時強化を常に実現できるとは限らないためであり、これは少し奇妙です。画像とラベルは両方とも [448,448] に均一にスケールされます。

torch.manual_seed(17)
# 自定义数据集CamVidDataset
class CamVidDataset(torch.utils.data.Dataset):
    """CamVid Dataset. Read images, apply augmentation and preprocessing transformations.
    
    Args:
        images_dir (str): path to images folder
        masks_dir (str): path to segmentation masks folder
        class_values (list): values of classes to extract from segmentation mask
        augmentation (albumentations.Compose): data transfromation pipeline 
            (e.g. flip, scale, etc.)
        preprocessing (albumentations.Compose): data preprocessing 
            (e.g. noralization, shape manipulation, etc.)
    """
    
    def __init__(self, images_dir, masks_dir):
        self.transform = A.Compose([
            A.Resize(448, 448),
            A.HorizontalFlip(),
            A.VerticalFlip(),
            A.Normalize(),
            ToTensorV2(),
        ]) 
        self.ids = os.listdir(images_dir)
        self.images_fps = [os.path.join(images_dir, image_id) for image_id in self.ids]
        self.masks_fps = [os.path.join(masks_dir, image_id) for image_id in self.ids]

    
    def __getitem__(self, i):
        # read data
        image = np.array(Image.open(self.images_fps[i]).convert('RGB'))
        mask = np.array( Image.open(self.masks_fps[i]).convert('RGB'))
        image = self.transform(image=image,mask=mask)
        
        return image['image'], image['mask'][:,:,0]
        
    def __len__(self):
        return len(self.ids)
    
    
# 设置数据集路径
DATA_DIR = r'dataset\camvid' # 根据自己的路径来设置
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')
    

train_dataset = CamVidDataset(
    x_train_dir, 
    y_train_dir, 
)
val_dataset = CamVidDataset(
    x_valid_dir, 
    y_valid_dir, 
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)

データセットとデータローダーを作成する

# 设置数据集路径
DATA_DIR = r'dataset\camvid' # 根据自己的路径来设置
x_train_dir = os.path.join(DATA_DIR, 'train_images')
y_train_dir = os.path.join(DATA_DIR, 'train_labels')
x_valid_dir = os.path.join(DATA_DIR, 'valid_images')
y_valid_dir = os.path.join(DATA_DIR, 'valid_labels')
    

train_dataset = CamVidDataset(
    x_train_dir, 
    y_train_dir, 
)
val_dataset = CamVidDataset(
    x_valid_dir, 
    y_valid_dir, 
)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)

データ拡張の結果を確認できます

for index, (img, label) in enumerate(train_loader):
    print(img.shape)
    print(label.shape)
    
    plt.figure(figsize=(10,10))
    plt.subplot(221)
    plt.imshow((img[0,:,:,:].moveaxis(0,2)))
    plt.subplot(222)
    plt.imshow(label[0,:,:])
    
    plt.subplot(223)
    plt.imshow((img[6,:,:,:].moveaxis(0,2)))
    plt.subplot(224)
    plt.imshow(label[6,:,:])
    
    plt.show() 
    if index==0:
        break

 (画像補正でNormalizeを行うと画像の色が少しおかしくなります) しかし、少なくともデータとラベルの同時補正の結果は得られました。

モデル構築

便宜上、モデルはエンコーダと SegNet の 2 つの部分に分割されています。

#Encoder模块

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder,self).__init__()
        #前13层是VGG16的前13层,分为5个stage
        #因为在下采样时要保存最大池化层的索引, 方便起见, 池化层不写在stage中
        self.stage_1 = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        
        self.stage_2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        
        self.stage_3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )     
        
        self.stage_4 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )   
        
        self.stage_5 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )     
        
    def forward(self, x):
        #用来保存各层的池化索引
        pool_indices = []
        x = x.float()
        
        x = self.stage_1(x)
        #pool_indice_1保留了第一个池化层的索引
        x, pool_indice_1 = nn.MaxPool2d( 2, stride=2, return_indices=True)(x)
        pool_indices.append(pool_indice_1)
        
        x = self.stage_2(x)
        x, pool_indice_2 = nn.MaxPool2d(2, stride=2, return_indices=True)(x)
        pool_indices.append(pool_indice_2)
        
        x = self.stage_3(x)
        x, pool_indice_3 = nn.MaxPool2d(2, stride=2, return_indices=True)(x)
        pool_indices.append(pool_indice_3)   
        
        x = self.stage_4(x)
        x, pool_indice_4 = nn.MaxPool2d(2, stride=2, return_indices=True)(x)
        pool_indices.append(pool_indice_4)
        
        x = self.stage_5(x)
        x, pool_indice_5 = nn.MaxPool2d(2, stride=2, return_indices=True)(x)
        pool_indices.append(pool_indice_5)
        
        return x, pool_indices
    
    
#SegNet网络, Encoder-Decoder
class SegNet(nn.Module):
    def __init__(self, num_classes):
        super(SegNet, self).__init__()
        #加载Encoder
        self.encoder = Encoder()
       #上采样 从下往上, 1->2->3->4->5
        self.upsample_1 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
        )
        
        self.upsample_2 = nn.Sequential(
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        
        self.upsample_3 = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )
        
        self.upsample_4 = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        
        self.upsample_5 = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, num_classes, kernel_size=3, stride=1, padding=1),
        )   
        
    def forward(self, x):
        x, pool_indices = self.encoder(x)
        
        #池化索引上采样
        x = nn.MaxUnpool2d(2, 2, padding=0)(x, pool_indices[4])
        x = self.upsample_1(x)
        
        x = nn.MaxUnpool2d(2, 2, padding=0)(x, pool_indices[3])
        x = self.upsample_2(x) 
        
        x = nn.MaxUnpool2d(2, 2, padding=0)(x, pool_indices[2])
        x = self.upsample_3(x)
        
        x = nn.MaxUnpool2d(2, 2, padding=0)(x, pool_indices[1])
        x = self.upsample_4(x)
        
        x = nn.MaxUnpool2d(2, 2, padding=0)(x, pool_indices[0])
        x = self.upsample_5(x)
        
        return x

モデルトレーニング

#载入预训练权重, 500M还挺大的 下载地址:https://download.pytorch.org/models/vgg16_bn-6c64b313.pth
model = SegNet(32+1).cuda()
model.load_state_dict(torch.load(r"checkpoints/vgg16_bn-6c64b313.pth"),strict=False)

from d2l import torch as d2l
#损失函数选用多分类交叉熵损失函数
lossf = nn.CrossEntropyLoss()
#选用adam优化器来训练
optimizer = optim.SGD(model.parameters(),lr=0.1)
#训练50轮
epochs_num = 50

d2l ライブラリの train 関数をデータセットに適応するように書き換えます。 

def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
               devices=d2l.try_all_gpus()):
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        # Sum of training loss, sum of training accuracy, no. of examples,
        # no. of predictions
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = d2l.train_batch_ch13(
                net, features, labels.long(), loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {metric[0] / metric[2]:.3f}, train acc '
          f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
          f'{str(devices)}')

トレーニングを開始する 

train_ch13(model, train_loader, val_loader, lossf, optimizer, epochs_num)

モデルのトレーニング結果は次のとおりで、テスト セットの精度は約 83% です。


要約する

SegNet は Encoder-Decoder 構造を使用しており、FCN ネットワークと比較して SegNet モデルは小さく、アップサンプリングの特徴回復では、プーリング インデックスを使用して画像の解像度を復元し、より詳細なセグメンテーション結果を取得します。

おすすめ

転載: blog.csdn.net/yumaomi/article/details/124766321