Deep learning image segmentation PSPNet paper reproduction (training and testing visualization)

1. Introduction to PSPNet

PSPNet (Pyramid Scene Parsing Network) comes from an article in CVPR2017. The Chinese translation is Pyramid Scene Parsing Network, which is mainly used for image segmentation. This architecture introduces a Pyramid Pooling module to capture contextual information at different scales. Pyramid Pooling can extract global and local context information at different scales, helping to better understand the semantic content in images, thereby improving segmentation performance.

1. Principle explanation

PSPNet framework diagram

  • (a) Input image
  • (b) Use the pre-trained ResNet model to obtain the feature map
  • (c) Use Pyramid Pooling to obtain representations of different sub-regions, and form feature representations containing local and global context information through upsampling and concat.
  • (d) Send the features to the convolution layer to obtain pixel-level prediction results

2. Paper explanation

The pyramid pooling module combines features at four different pyramid scales. Marked in red is global pooling to produce a single bin output. The following pyramid levels divide the feature map into different sub-regions and form pooled representations at different locations. The output of different levels in the pyramid pooling module contains feature maps of different sizes. In order to maintain the weight of global features, a 1 × 1 convolutional layer is used after each pyramid level to reduce the dimensionality of the context representation to 1/N of the original dimension (if the pyramid level size is N).

The low-dimensional feature map is then upsampled , and features of the same size as the original feature map are obtained through bilinear interpolation . Finally, the features at different levels are connected as the global features of the final pyramid pooling output.

3. Network model

class PSPNet(BaseModel):
    def __init__(self, num_classes, in_channels=3, backbone='resnet152', pretrained=True, use_aux=True, freeze_bn=False, freeze_backbone=False):
        super(PSPNet, self).__init__()
        norm_layer = nn.BatchNorm2d  # 用于规范化的层类型

        # 使用getattr根据backbone参数选择合适的骨干网络模型,并可能加载预训练权重
        model = getattr(resnet, backbone)(pretrained, norm_layer=norm_layer)
        m_out_sz = model.fc.in_features  # 提取骨干网络的输出特征通道数

        self.use_aux = use_aux  # 是否使用辅助分割分支

        # 初始卷积层,根据in_channels来调整输入通道数
        self.initial = nn.Sequential(*list(model.children())[:4])
        if in_channels != 3:
            self.initial[0] = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.initial = nn.Sequential(*self.initial)

        # 骨干网络的不同层
        self.layer1 = model.layer1
        self.layer2 = model.layer2
        self.layer3 = model.layer3
        self.layer4 = model.layer4

        # 主要分割分支,包括特征融合和分割输出
        self.master_branch = nn.Sequential(
            PSPModule(m_out_sz, bin_sizes=[1, 2, 3, 6], norm_layer=norm_layer),  # 特征融合模块
            nn.Conv2d(m_out_sz // 4, num_classes, kernel_size=1)  # 分割输出卷积层
        )

        # 辅助分割分支,可选,用于训练时帮助主分割任务
        self.auxiliary_branch = nn.Sequential(
            nn.Conv2d(m_out_sz // 2, m_out_sz // 4, kernel_size=3, padding=1, bias=False),
            norm_layer(m_out_sz // 4),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(m_out_sz // 4, num_classes, kernel_size=1)
        )

        # 初始化网络权重
        initialize_weights(self.master_branch, self.auxiliary_branch)

    def forward(self, x):
        input_size = (x.size()[2], x.size()[3])  # 记录输入图像的尺寸

        x = self.initial(x)  # 初始卷积层
        x = self.layer1(x)  # 第一层
        x = self.layer2(x)  # 第二层
        x_aux = self.layer3(x)  # 第三层,用于辅助分割分支
        x = self.layer4(x)  # 第四层

        output = self.master_branch(x)  # 主要分割分支
        output = F.interpolate(output, size=input_size, mode='bilinear')  # 插值操作,将分割输出大小调整为输入大小
        output = output[:, :, :input_size[0], :input_size[1]]  # 调整输出的尺寸以匹配输入

        # 如果在训练模式下且使用辅助分割分支,还生成辅助分割输出
        if self.training and self.use_aux:
            aux = self.auxiliary_branch(x_aux)
            aux = F.interpolate(aux, size=input_size, mode='bilinear')  # 调整辅助分割输出大小
            aux = aux[:, :, :input_size[0], :input_size[1]]  # 调整输出的尺寸以匹配输入
            return output, aux  # 返回主分割输出和辅助分割输出
        return output  # 只返回主分割输出

其中,PSPModule类的定义如下
class PSPModule(nn.Module):
    def __init__(self, in_channels, bin_sizes, norm_layer):
        super(_PSPModule, self).__init__()

        # 计算每个池化分支的输出通道数
        out_channels = in_channels // len(bin_sizes)

        # 创建池化分支,将它们存储在一个 ModuleList 中
        self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s, norm_layer) 
                                                        for b_s in bin_sizes])

        # 创建特征融合模块(bottleneck)
        self.bottleneck = nn.Sequential(
            nn.Conv2d(in_channels + (out_channels * len(bin_sizes)), out_channels, 
                                    kernel_size=3, padding=1, bias=False),  # 卷积层
            norm_layer(out_channels),  # 规范化层
            nn.ReLU(inplace=True),  # ReLU激活函数
            nn.Dropout2d(0.1)  # 2D Dropout层
        )

    def _make_stages(self, in_channels, out_channels, bin_sz, norm_layer):
        # 创建池化分支的内部结构
        prior = nn.AdaptiveAvgPool2d(output_size=bin_sz)  # 自适应平均池化层
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)  # 卷积层
        bn = norm_layer(out_channels)  # 规范化层
        relu = nn.ReLU(inplace=True)  # ReLU激活函数
        return nn.Sequential(prior, conv, bn, relu)  # 返回池化分支的Sequential模块

    def forward(self, features):
        h, w = features.size()[2], features.size()[3]  # 获取输入特征的高度和宽度

        pyramids = [features]  # 存储原始特征到金字塔中

        # 遍历每个池化分支,对特征进行插值操作并存储在金字塔中
        pyramids.extend([F.interpolate(stage(features), size=(h, w), mode='bilinear', 
                                        align_corners=True) for stage in self.stages])

        # 将金字塔中的特征拼接在一起并通过特征融合模块
        output = self.bottleneck(torch.cat(pyramids, dim=1))
        return output  # 返回特征融合后的输出

This class is used to perform pyramid pooling and feature fusion operations and fuse them into a feature representation with richer semantic information.

The core idea of ​​PSPNet is to use a 4-level pyramid structure. The pooling kernel can cover the whole, half and small portions of the image, that is,

 self.stages = nn.ModuleList([self._make_stages(in_channels, out_channels, b_s, norm_layer) 
                                                  for b_s in bin_sizes])

In the above code, self.stages contains multiple pooling branches, and bin_sizes is a list containing 4 elements, corresponding to 4 different pooling branches. By traversing bin_sizes, 4 pooling branches are created using the make_stages method. Each pooling branch consists of an adaptive average pooling layer, a convolutional layer, a normalization layer and a ReLU activation layer. This forms the feature extraction part of the pyramid structure.

The four pooling branches have different receptive field sizes to capture image information at different scales.

In other words, each element in self.stages represents a level in the pyramid, reflecting the 4-level pyramid structure. The forward method will traverse these pooling branches and perform interpolation operations on the input features, resizing them to the same size as the original features for feature fusion.

2. Deployment and implementation

My environment is

  • Operating system: win11
  • Language: python3.10
  • IDE:PyCharm 2023
  • GPU:RTX 4060

1、PASCAL VOC 2012

Dataset uses the classic PASCAL VOC 2012, a standard data set for computer vision research, which provides image data and related annotations for a variety of tasks. It contains 20 different object categories such as airplanes, bicycles, cars, dogs, cats, chairs, etc., as well as a category "background". These images are collected from the real world, cover different scenes and angles, and represent common everyday objects. In terms of size, it contains 1,464 training images, 1,449 verification images, and 1,456 test images. Each image comes with detailed annotation information, including bounding boxes for each object instance (object detection task) and pixel-level semantic segmentation labels (semantic segmentation task).

2. Model training

核心训练部分的代码如下:
def _train_epoch(self, epoch):
    self.logger.info('\n')  # 打印日志信息

    self.model.train()  # 设置模型为训练模式
    if self.config['arch']['args']['freeze_bn']:  
        if isinstance(self.model, torch.nn.DataParallel):
            self.model.module.freeze_bn() 
        else:
            self.model.freeze_bn()  
    self.wrt_mode = 'train'  # 设置写入模式为'train'

    tic = time.time()  # 记录当前时间
    self._reset_metrics()  # 重置度量指标
    tbar = tqdm(self.train_loader, ncols=130)  # 创建一个进度条以迭代训练数据集

    for batch_idx, (data, target) in enumerate(tbar):  # 遍历训练数据
        self.data_time.update(time.time() - tic)  # 更新数据加载时间

        self.lr_scheduler.step(epoch=epoch - 1)  # 根据当前训练的epoch调整学习率

        # LOSS & OPTIMIZE
        self.optimizer.zero_grad()  # 清零梯度
        output = self.model(data)  # 前向传播,获取模型输出
        if self.config['arch']['type'][:3] == 'PSP':  
            assert output[0].size()[2:] == target.size()[1:]  # 检查输出和目标的空间尺寸匹配
            assert output[0].size()[1] == self.num_classes  # 检查输出通道数与类别数匹配
            loss = self.loss(output[0], target)  # 计算损失
            loss += self.loss(output[1], target) * 0.4  # 添加辅助损失,加权为0.4
            output = output[0]  # 将主要输出作为最终输出
        else:
            assert output.size()[2:] == target.size()[1:]  
            assert output.size()[1] == self.num_classes  
            loss = self.loss(output, target) 

        if isinstance(self.loss, torch.nn.DataParallel):  
            loss = loss.mean()  # 计算损失的均值
        loss.backward()  # 反向传播,计算梯度
        self.optimizer.step()  # 更新模型参数
        self.total_loss.update(loss.item())  # 更新总损失

        # measure elapsed time
        self.batch_time.update(time.time() - tic)  # 更新批次处理时间
        tic = time.time()

        # LOGGING & TENSORBOARD
        if batch_idx % self.log_step == 0:  # 每隔一定步数记录一次日志和TensorBoard
            self.wrt_step = (epoch - 1) * len(self.train_loader) + batch_idx  # 当前步数
            self.writer.add_scalar(f'{
      
      self.wrt_mode}/loss', loss.item(), self.wrt_step)  # 记录损失到TensorBoard

        # FOR EVAL
        seg_metrics = eval_metrics(output, target, self.num_classes)  # 计算分割度量指标
        self._update_seg_metrics(*seg_metrics)  # 更新分割度量指标
        pixAcc, mIoU, _ = self._get_seg_metrics().values()  # 获取分割指标值

        # PRINT INFO
        tbar.set_description('TRAIN ({}) | Loss: {:.3f} | Acc {:.2f} mIoU {:.2f} | B {:.2f} D {:.2f} |'.format(
                                            epoch, self.total_loss.average, 
                                            pixAcc, mIoU,
                                            self.batch_time.average, self.data_time.average))  # 打印训练信息

    # METRICS TO TENSORBOARD
    seg_metrics = self._get_seg_metrics()
    for k, v in list(seg_metrics.items())[:-1]:  # 遍历分割度量指标并记录
        self.writer.add_scalar(f'{
      
      self.wrt_mode}/{
      
      k}', v, self.wrt_step)
    for i, opt_group in enumerate(self.optimizer.param_groups):  # 记录学习率
        self.writer.add_scalar(f'{
      
      self.wrt_mode}/Learning_rate_{
      
      i}', opt_group['lr'], self.wrt_step)

    # RETURN LOSS & METRICS
    log = {
    
    'loss': self.total_loss.average,  # 返回平均损失
            **seg_metrics}  # 返回分割度量指标

    return log  # 返回日志信息
交叉验证部分,我们进行以下的定义
def _valid_epoch(self, epoch):
    if self.val_loader is None:
        self.logger.warning('Not data loader was passed for the validation step, No validation is performed !')
        return {
    
    }  # 如果没有提供验证数据加载器,发出警告并返回一个空字典
    self.logger.info('\n###### EVALUATION ######')

    self.model.eval()  # 设置模型为评估(验证)模式
    self.wrt_mode = 'val'  # 设置写入模式为'val'(用于TensorBoard记录)

    self._reset_metrics()  # 重置度量指标
    tbar = tqdm(self.val_loader, ncols=130)  # 创建一个进度条用于遍历验证数据集
    with torch.no_grad():  # 禁用梯度计算
        val_visual = []  # 用于可视化的图像列表
        for batch_idx, (data, target) in enumerate(tbar):
            #data, target = data.to(self.device), target.to(self.device)  # 将数据和目标移到指定的设备上(通常是GPU)
            # LOSS
            output = self.model(data)  # 前向传播,获取模型的输出
            loss = self.loss(output, target)  # 计算损失
            if isinstance(self.loss, torch.nn.DataParallel):  # 如果损失函数是DataParallel损失函数
                loss = loss.mean()  # 计算损失的均值
            self.total_loss.update(loss.item())  # 更新总损失

            seg_metrics = eval_metrics(output, target, self.num_classes)  # 计算分割度量指标
            self._update_seg_metrics(*seg_metrics)  # 更新分割度量指标

            # LIST OF IMAGE TO VIZ (15 images)
            if len(val_visual) < 15:  # 用于可视化的图像数量限制在15张以内
                target_np = target.data.cpu().numpy()  # 将目标从GPU移到CPU并转换为NumPy数组
                output_np = output.data.max(1)[1].cpu().numpy()  # 将模型输出的类别概率最大的类别作为预测结果
                val_visual.append([data[0].data.cpu(), target_np[0], output_np[0]])  # 添加可视化所需的图像和标签

            # PRINT INFO
            pixAcc, mIoU, _ = self._get_seg_metrics().values()  # 获取分割度量指标的值
            tbar.set_description('EVAL ({}) | Loss: {:.3f}, PixelAcc: {:.2f}, Mean IoU: {:.2f} |'.format( epoch,
                                            self.total_loss.average,
                                            pixAcc, mIoU))  # 打印验证信息

        # WRITING & VISUALIZING THE MASKS
        val_img = []  # 用于可视化的图像列表
        palette = self.train_loader.dataset.palette  # 获取调色板信息
        for d, t, o in val_visual:  # 遍历可视化图像列表
            d = self.restore_transform(d)  # 还原图像的转换(例如,去均值、缩放等)
            t, o = colorize_mask(t, palette), colorize_mask(o, palette)  # 将标签和模型输出转换为彩色掩码
            d, t, o = d.convert('RGB'), t.convert('RGB'), o.convert('RGB')  # 将图像转换为RGB格式
            [d, t, o] = [self.viz_transform(x) for x in [d, t, o]]  # 应用可视化转换
            val_img.extend([d, t, o])  # 添加可视化图像到列表中
        val_img = torch.stack(val_img, 0)  # 将可视化图像堆叠成一个张量
        val_img = make_grid(val_img.cpu(), nrow=3, padding=5)  # 使用Grid方式排列可视化图像
        self.writer.add_image(f'{
      
      self.wrt_mode}/inputs_targets_predictions', val_img, self.wrt_step)  # 将可视化图像写入TensorBoard

        # METRICS TO TENSORBOARD
        self.wrt_step = (epoch) * len(self.val_loader)  # 计算当前步数
        self.writer.add_scalar(f'{
      
      self.wrt_mode}/loss', self.total_loss.average, self.wrt_step)  # 记录平均损失到TensorBoard
        seg_metrics = self._get_seg_metrics()  # 获取分割度量指标
        for k, v in list(seg_metrics.items())[:-1]:  # 遍历分割度量指标并记录到TensorBoard
            self.writer.add_scalar(f'{
      
      self.wrt_mode}/{
      
      k}', v, self.wrt_step)

        log = {
    
    
            'val_loss': self.total_loss.average,  # 返回平均验证损失
            **seg_metrics  # 返回分割度量指标
        }

    return log  # 返回日志信息

# 以下是用于度量指标的辅助函数
def _reset_metrics(self):
    self.batch_time = AverageMeter()  # 用于记录批次处理时间的平均值
    self.data_time = AverageMeter()  # 用于记录数据加载时间的平均值
    self.total_loss = AverageMeter()  # 用于记录总损失的平均值
    self.total_inter, self.total_union = 0, 0  # 用于记录交集和并集的总和
    self.total_correct, self.total_label = 0, 0  # 用于记录正确分类和标签的总和

def _update_seg_metrics(self, correct, labeled, inter, union):
    self.total_correct += correct  # 更新正确分类的数量
    self.total_label += labeled  # 更新标签的数量
    self.total_inter += inter  # 更新交集的总和
    self.total_union += union  # 更新并集的总和

def _get_seg_metrics(self):
    pixAcc = 1.0 * self.total_correct / (np.spacing(1) + self.total_label)  # 计算像素准确率
    IoU = 1.0 * self.total_inter / (np.spacing(1) + self.total_union)  # 计算各类别的IoU
    mIoU = IoU.mean()  # 计算平均IoU
    return {
    
    
        "Pixel_Accuracy": np.round(pixAcc, 3),  # 返回像素准确率
        "Mean_IoU": np.round(mIoU, 3),  # 返回平均IoU
        "Class_IoU": dict(zip(range(self.num_classes), np.round(IoU, 3)))  # 返回各类别的IoU
    }

3. Metrics

Here we use two metrics to evaluate the performance of the model:

Pixel_Accuracy is a metric used to evaluate the performance of image segmentation tasks. It measures the proportion of the number of pixels correctly classified by the model to the total number of pixels across the entire image. It can be expressed simply as the following mathematical formula:

Insert image description here
in:

  • “Number of Correctly Classified Pixels” indicates the number of pixels that the model correctly classified in the segmented image.
  • “Total Number of Pixels” represents the total number of pixels in the entire segmented image.

Pixel Accuracy ranges from 0 to 1, where 1 means that the model correctly classified all pixels across the entire image, and 0 means that the model did not correctly classify any pixels.

Mean_IoU : IoU (Intersection over Union) is an indicator that indicates the degree of overlap between two sets, and is usually used in segmentation tasks. In segmentation tasks, one set represents the model's predicted segmentation regions, and the other set represents the true segmentation regions. The calculation formula of IoU is as follows:

Insert image description here

in:

  • Area of ​​Intersection" is the intersection area between the model's predicted segmentation area and the real segmentation area.
  • Area of ​​Union" is the union area of ​​the model's predicted segmentation area and the real segmentation area.

or
Insert image description here
where:

  • TP (True Positives): Indicates the number of pixels correctly predicted by the model to be positive (target category)
  • FP (False Positives): Indicates the number of pixels in which the model incorrectly predicts background pixels as positive.
  • FN (False Negatives): Indicates the number of pixels in which the model incorrectly predicts positive pixels as background

Mean Intersection over Union (Mean IoU) is the average of all categories of IoU

Insert image description here
Among them, N is the number of categories, IoU_i is the IoU of the i-th category

4. Result analysis

The main parameter settings are as follows:

"epochs": 80,
"loss": "CrossEntropyLoss2d",
"batch_size": 8,
"base_size": 400,  //图像大小调整为base_size,然后随机裁剪
"crop_size": 380,  //重新缩放后随机裁剪的大小
 "optimizer": {
    
    
        "type": "SGD",
        "differential_lr": true,
        "args":{
    
    
            "lr": 0.01,
            "weight_decay": 1e-4,
            "momentum": 0.9
        }
    },

Due to limited GPU resources, only 80 epochs are run here. The log information obtained is as follows: Information
Insert image description here
recorded by Tensorboard: Comparison of Input, Ground Truth and Output of
the train
Insert image description here
validation
Insert image description here
cross-validation set:
Insert image description here
Some details are not completely segmented, but the image can be recognized main body. There is not much difference between train_loss and val_loss, and there is no overfitting. Increasing the training period may achieve better results.

5. Image test

The test part code is as follows:

    args = parse_arguments()  # 解析命令行参数
    config = json.load(open(args.config))  # 从JSON文件中加载配置信息

    # 根据配置信息创建数据加载器
    loader = getattr(dataloaders, config['train_loader']['type'])(**config['train_loader']['args'])
    to_tensor = transforms.ToTensor()  # 创建图像到张量的转换
    normalize = transforms.Normalize(loader.MEAN, loader.STD)  # 创建归一化转换
    num_classes = loader.dataset.num_classes  # 获取数据集中的类别数量
    palette = loader.dataset.palette  # 获取颜色映射表

    # 创建模型
    model = getattr(models, config['arch']['type'])(num_classes, **config['arch']['args'])  # 根据配置创建模型
    availble_gpus = list(range(torch.cuda.device_count()))  # 获取可用的GPU列表
    device = torch.device('cuda:0' if len(availble_gpus) > 0 else 'cpu')  # 选择运行设备(GPU或CPU)

    # 加载模型检查点
    checkpoint = torch.load(args.model, map_location=device)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint.keys():
        checkpoint = checkpoint['state_dict']
    
    # 如果在训练期间使用了数据并行,需要处理模型
    if 'module' in list(checkpoint.keys())[0] and not isinstance(model, torch.nn.DataParallel):
        # 对于GPU推理,使用数据并行
        if "cuda" in device.type:
            model = torch.nn.DataParallel(model)
        else:
            # 对于CPU推理,移除模型的"module"前缀
            new_state_dict = OrderedDict()
            for k, v in checkpoint.items():
                name = k[7:]
                new_state_dict[name] = v
            checkpoint = new_state_dict
    
    # 加载模型权重
    model.load_state_dict(checkpoint)
    model.to(device)  # 将模型移动到所选设备
    model.eval()  # 设置模型为评估模式

    # 创建输出目录
    if not os.path.exists('outputs'):
        os.makedirs('outputs')

    # 获取图像文件列表
    image_files = sorted(glob(os.path.join(args.images, f'*.{
      
      args.extension}')))
    with torch.no_grad():
        tbar = tqdm(image_files, ncols=100)  # 创建进度条
        for img_file in tbar:
            image = Image.open(img_file).convert('RGB')  # 打开图像并将其转换为RGB格式
            input = normalize(to_tensor(image)).unsqueeze(0)  # 转换图像并添加批次维度

			#预测图像分割结果
			prediction = multi_scale_predict(model, input, scales, num_classes, device)
         
            prediction = F.softmax(torch.from_numpy(prediction), dim=0).argmax(0).cpu().numpy()  # 计算最终的预测结果
            save_images(image, prediction, args.output, img_file, palette)  # 保存预测结果的图像

The multi-scale image prediction function is defined as:

def multi_scale_predict(model, image, scales, num_classes, device, flip=False):
    # 获取输入图像的尺寸
    input_size = (image.size(2), image.size(3))
    # 创建上采样层,用于将不同尺度的预测结果恢复到原始尺寸
    upsample = nn.Upsample(size=input_size, mode='bilinear', align_corners=True)
    # 初始化用于累计预测结果的数组
    total_predictions = np.zeros((num_classes, image.size(2), image.size(3)))

    # 将输入图像转换为NumPy数组,并移动到CPU上
    image = image.data.data.cpu().numpy()
    
    # 遍历不同的尺度
    for scale in scales:
        # 缩放图像
        scaled_img = ndimage.zoom(image, (1.0, 1.0, float(scale), float(scale)), order=1, prefilter=False)
        # 将缩放后的图像转换为PyTorch张量并移动到指定设备
        scaled_img = torch.from_numpy(scaled_img).to(device)
        # 使用模型进行预测并上采样到原始尺寸
        scaled_prediction = upsample(model(scaled_img).cpu())

        # 如果启用了翻转,对翻转后的图像进行预测并平均
        if flip:
            fliped_img = scaled_img.flip(-1).to(device)
            fliped_predictions = upsample(model(fliped_img).cpu())
            scaled_prediction = 0.5 * (fliped_predictions.flip(-1) + scaled_prediction)
        
        # 将当前尺度的预测结果累加到总体预测中
        total_predictions += scaled_prediction.data.cpu().numpy().squeeze(0)

    # 计算平均预测结果
    total_predictions /= len(scales)
    return total_predictions

We arbitrarily specify an input image and test the segmentation effect of the model.
Insert image description here
The effect is not good. If you have the resources, you can increase the number of epoch training and try different data sets
———————————————— ——————————————————————————

Guess you like

Origin blog.csdn.net/LPYchengxuyuan/article/details/133418489