[Condinst]Conditional Convolutions for Instance Segmentation笔记

Conditional Convolutions for Instance Segmentation

如果对你有帮助的话,希望给我点个赞~

网络结构

在这里插入图片描述

mask head

在这里插入图片描述
总的来说,CondInst = FCOS(cls + reg + ctrness)+ controller 卷积在 FCOS head 特征上生成的 top_feats(即 dynamic mask head 的动态参数,channel: 256 --> 169)+ mask branch。论文里 mask branch 只画了 P3 层,但代码中 self.in_features 是 ['p3', 'p4', 'p5']:每一层先经过各自的 refine 结构(channel: 256 --> 128),再把 P4、P5 用 aligned_bilinear 上采样到 P3 的分辨率后与 P3 逐元素相加(sum),最后经过 tower 结构(channel: 128 --> 8)。
top_feats 的形状以及 refine、tower 两个模块的结构见:

'''
top_feats
	 in CondInst:
	 (Pdb) top_feats[0].size()
	 torch.Size([2, 169, 100, 152])
	 (Pdb) top_feats[1].size()
	 torch.Size([2, 169, 50, 76])
	 (Pdb) top_feats[2].size()
	 torch.Size([2, 169, 25, 38])
	 (Pdb) top_feats[3].size()
	 torch.Size([2, 169, 13, 19])
	 (Pdb) top_feats[4].size()
	 torch.Size([2, 169, 7, 10])
     
 '''
'''
MaskBranch(
  (refine): ModuleList(
    (0): Sequential(
      (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
  (tower): Sequential(
    (0): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (4): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1))
  )
)

'''


LOSS

在这里插入图片描述
在这里插入图片描述

1. AdelaiDet/adet/modeling/condinst/condinst.py

# -*- coding: utf-8 -*-
import logging

import torch
from torch import nn
import torch.nn.functional as F

from detectron2.structures import ImageList
from detectron2.modeling.proposal_generator import build_proposal_generator
from detectron2.modeling.backbone import build_backbone
from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY
from detectron2.structures.instances import Instances
from detectron2.structures.masks import PolygonMasks, polygons_to_bitmask

from .dynamic_mask_head import build_dynamic_mask_head
from .mask_branch import build_mask_branch

from adet.utils.comm import aligned_bilinear
import pdb
__all__ = ["CondInst"]


logger = logging.getLogger(__name__)


@META_ARCH_REGISTRY.register()
class CondInst(nn.Module):
    """
    Main class for CondInst architectures (see https://arxiv.org/abs/2003.05664).

    Sub-modules:
        backbone: feature extractor built by ``build_backbone(cfg)``.
        proposal_generator: detector built by ``build_proposal_generator``. It
            receives ``controller`` in ``forward`` and returns instances that
            carry the controller outputs (``top_feats`` in training,
            ``top_feat`` at inference).
        controller: 3x3 conv with ``mask_head.num_gen_params`` output channels,
            i.e. the flattened parameters of the dynamic mask head.
        mask_branch: produces the shared mask feature map ``mask_feats``.
        mask_head: dynamic mask head evaluated per instance on ``mask_feats``.
    """

    def __init__(self, cfg):
        super().__init__()
        self.device = torch.device(cfg.MODEL.DEVICE)

        self.backbone = build_backbone(cfg)
        self.proposal_generator = build_proposal_generator(cfg, self.backbone.output_shape())
        self.mask_head = build_dynamic_mask_head(cfg)
        self.mask_branch = build_mask_branch(cfg, self.backbone.output_shape())
        # Stride (w.r.t. the padded input) at which GT masks are sampled in
        # add_bitmasks(); e.g. 4 in the config used for these notes.
        self.mask_out_stride = cfg.MODEL.CONDINST.MASK_OUT_STRIDE
        # Max proposals fed to the mask head in training; a negative value
        # disables clipping (see _forward_mask_heads_train).
        self.max_proposals = cfg.MODEL.CONDINST.MAX_PROPOSALS

        # build top module ("controller"): maps detector features to the
        # dynamic mask-head parameters (e.g. 256 -> 169 channels).
        in_channels = self.proposal_generator.in_channels_to_top_module

        self.controller = nn.Conv2d(
            in_channels, self.mask_head.num_gen_params,
            kernel_size=3, stride=1, padding=1
        )
        torch.nn.init.normal_(self.controller.weight, std=0.01)
        torch.nn.init.constant_(self.controller.bias, 0)

        pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1)
        pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1)
        self.normalizer = lambda x: (x - pixel_mean) / pixel_std
        # Move all sub-modules to the configured device.
        self.to(self.device)

    def forward(self, batched_inputs):
        """
        Args:
            batched_inputs (list[dict]): one dict per image. Each dict needs an
                ``"image"`` tensor. ``"instances"`` (ground truth) is used when
                present in the first dict. ``"height"`` / ``"width"``
                optionally set the output resolution at inference.

        Returns:
            In training: a dict of losses (mask-branch losses, proposal losses
            and ``"loss_mask"``).
            At inference: a list with one ``{"instances": Instances}`` per image.
        """
        images = [x["image"].to(self.device) for x in batched_inputs]
        images = [self.normalizer(x) for x in images]
        # Pads all images to a common size divisible by the backbone's
        # size_divisibility.
        images = ImageList.from_tensors(images, self.backbone.size_divisibility)
        features = self.backbone(images.tensor)

        if "instances" in batched_inputs[0]:
            gt_instances = [x["instances"].to(self.device) for x in batched_inputs]
            # Rasterize GT masks at the padded batch resolution.
            self.add_bitmasks(gt_instances, images.tensor.size(-2), images.tensor.size(-1))
        else:
            gt_instances = None

        mask_feats, sem_losses = self.mask_branch(features, gt_instances)

        proposals, proposal_losses = self.proposal_generator(
            images, features, gt_instances, self.controller
        )

        if self.training:
            loss_mask = self._forward_mask_heads_train(proposals, mask_feats, gt_instances)

            losses = {}
            losses.update(sem_losses)
            losses.update(proposal_losses)
            losses.update({"loss_mask": loss_mask})
            return losses

        pred_instances_w_masks = self._forward_mask_heads_test(proposals, mask_feats)

        padded_im_h, padded_im_w = images.tensor.size()[-2:]
        processed_results = []
        for im_id, (input_per_image, image_size) in enumerate(zip(batched_inputs, images.image_sizes)):
            height = input_per_image.get("height", image_size[0])
            width = input_per_image.get("width", image_size[1])

            # Predictions of all images were concatenated; select this image's.
            instances_per_im = pred_instances_w_masks[pred_instances_w_masks.im_inds == im_id]
            instances_per_im = self.postprocess(
                instances_per_im, height, width,
                padded_im_h, padded_im_w
            )

            processed_results.append({"instances": instances_per_im})
        return processed_results

    def _forward_mask_heads_train(self, proposals, mask_feats, gt_instances):
        """
        Compute the mask loss for the training proposals.

        ``proposals["instances"]`` is randomly subsampled to at most
        ``self.max_proposals`` entries when that limit is non-negative. The
        controller outputs (``top_feats``) then become the mask-head params.
        """
        pred_instances = proposals["instances"]

        if 0 <= self.max_proposals < len(pred_instances):
            inds = torch.randperm(len(pred_instances), device=mask_feats.device).long()
            logger.info("clipping proposals from {} to {}".format(
                len(pred_instances), self.max_proposals
            ))
            pred_instances = pred_instances[inds[:self.max_proposals]]

        pred_instances.mask_head_params = pred_instances.top_feats

        loss_mask = self.mask_head(
            mask_feats, self.mask_branch.out_stride,
            pred_instances, gt_instances
        )
        return loss_mask

    def _forward_mask_heads_test(self, proposals, mask_feats):
        """
        Run the mask head on the inference proposals of all images at once.

        Each per-image Instances gets an ``im_inds`` field before
        concatenation, so ``forward`` can split the results again.
        """
        for im_id, per_im in enumerate(proposals):
            per_im.im_inds = per_im.locations.new_ones(len(per_im), dtype=torch.long) * im_id
        pred_instances = Instances.cat(proposals)
        # NOTE(review): inference reads `top_feat` while training reads
        # `top_feats`; presumably these match the field names the FCOS outputs
        # set in each mode -- verify against fcos_outputs.py before renaming.
        pred_instances.mask_head_params = pred_instances.top_feat

        pred_instances_w_masks = self.mask_head(
            mask_feats, self.mask_branch.out_stride, pred_instances
        )
        return pred_instances_w_masks

    def add_bitmasks(self, instances, im_h, im_w):
        """
        Attach rasterized GT masks to each per-image Instances (in place).

        For every image that has ``gt_masks`` this sets two fields:
          * ``gt_bitmasks_full``: masks at the padded size ``(im_h, im_w)``.
          * ``gt_bitmasks``: the same masks sampled every ``mask_out_stride``
            pixels, starting at offset ``mask_out_stride // 2``.

        Polygon masks are rasterized with ``polygons_to_bitmask``. Any other
        mask type is read through its ``.tensor`` attribute (presumably
        detectron2 ``BitMasks``) and zero-padded to ``(im_h, im_w)``.
        """
        stride = self.mask_out_stride
        start = int(stride // 2)
        for per_im_gt_inst in instances:
            if not per_im_gt_inst.has("gt_masks"):
                continue
            gt_masks = per_im_gt_inst.get("gt_masks")
            if isinstance(gt_masks, PolygonMasks):
                per_im_bitmasks = []
                per_im_bitmasks_full = []
                for per_polygons in gt_masks.polygons:
                    bitmask = polygons_to_bitmask(per_polygons, im_h, im_w)
                    bitmask = torch.from_numpy(bitmask).to(self.device).float()
                    bitmask_full = bitmask.clone()
                    bitmask = bitmask[start::stride, start::stride]

                    assert bitmask.size(0) * stride == im_h
                    assert bitmask.size(1) * stride == im_w

                    per_im_bitmasks.append(bitmask)
                    per_im_bitmasks_full.append(bitmask_full)

                if per_im_bitmasks:
                    per_im_gt_inst.gt_bitmasks = torch.stack(per_im_bitmasks, dim=0)
                    per_im_gt_inst.gt_bitmasks_full = torch.stack(per_im_bitmasks_full, dim=0)
                else:
                    # No polygons in this image: torch.stack([]) would raise,
                    # so emit empty tensors with the expected spatial sizes.
                    per_im_gt_inst.gt_bitmasks = torch.zeros(
                        (0, len(range(start, im_h, stride)), len(range(start, im_w, stride))),
                        device=self.device,
                    )
                    per_im_gt_inst.gt_bitmasks_full = torch.zeros(
                        (0, im_h, im_w), device=self.device
                    )
            else:  # bitmask format (e.g. decoded RLE)
                bitmasks = gt_masks.tensor
                h, w = bitmasks.size()[1:]
                # Zero-pad bottom/right up to the padded batch size.
                bitmasks_full = F.pad(bitmasks, (0, im_w - w, 0, im_h - h), "constant", 0)
                bitmasks = bitmasks_full[:, start::stride, start::stride]
                per_im_gt_inst.gt_bitmasks = bitmasks
                per_im_gt_inst.gt_bitmasks_full = bitmasks_full

    def postprocess(self, results, output_height, output_width, padded_im_h, padded_im_w, mask_threshold=0.5):
        """
        Resize the output instances.
        The input images are often resized when entering an object detector.
        As a result, we often need the outputs of the detector in a different
        resolution from its inputs.
        This function will resize the raw outputs of an R-CNN detector
        to produce outputs according to the desired output resolution.
        Args:
            results (Instances): the raw outputs from the detector.
                `results.image_size` contains the input image resolution the detector sees.
                This object might be modified in-place.
            output_height, output_width: the desired output resolution.
            padded_im_h, padded_im_w: size of the padded batch tensor; used to
                infer the upsampling factor of ``pred_global_masks``.
            mask_threshold: probability above which a mask pixel is foreground.
        Returns:
            Instances: the resized output from the model, based on the output resolution
        Raises:
            AssertionError: if ``results`` has neither ``pred_boxes`` nor
                ``proposal_boxes``.
        """
        scale_x, scale_y = (output_width / results.image_size[1], output_height / results.image_size[0])
        resized_im_h, resized_im_w = results.image_size
        results = Instances((output_height, output_width), **results.get_fields())

        if results.has("pred_boxes"):
            output_boxes = results.pred_boxes
        elif results.has("proposal_boxes"):
            output_boxes = results.proposal_boxes
        else:
            output_boxes = None
        # Fail with a clear message instead of an UnboundLocalError below.
        assert output_boxes is not None, "Predictions must contain boxes!"

        output_boxes.scale(scale_x, scale_y)
        # `results` was rebuilt above, so image_size is now the output size.
        output_boxes.clip(results.image_size)

        results = results[output_boxes.nonempty()]

        if results.has("pred_global_masks"):
            mask_h, mask_w = results.pred_global_masks.size()[-2:]
            factor_h = padded_im_h // mask_h
            factor_w = padded_im_w // mask_w
            assert factor_h == factor_w
            factor = factor_h
            # Upsample to the padded input size, crop away the padding, then
            # resize to the requested output resolution.
            pred_global_masks = aligned_bilinear(
                results.pred_global_masks, factor
            )
            pred_global_masks = pred_global_masks[:, :, :resized_im_h, :resized_im_w]
            pred_global_masks = F.interpolate(
                pred_global_masks,
                size=(output_height, output_width),
                mode="bilinear", align_corners=False
            )
            pred_global_masks = pred_global_masks[:, 0, :, :]
            results.pred_masks = (pred_global_masks > mask_threshold).float()

        return results

'''
(Pdb) gt_instances
[Instances(num_instances=5, image_height=768, image_width=1229, fields=[gt_boxes: Boxes(tensor([[ 788.3651,  355.6032, 1102.0674,  613.4592],
        [ 157.3120,  426.8160,  239.3862,  499.2768],
        [ 234.8158,  432.5568,  293.6734,  479.7504],
        [ 373.0399,  401.1456,  441.9791,  500.7936],
        [ 312.8381,  432.5568,  346.6740,  450.7008]], device='cuda:0')), gt_classes: tensor([19, 19, 19, 19, 19], device='cuda:0'), gt_masks: PolygonMasks(num_instances=5)]), Instances(num_instances=4, image_height=704, image_width=939, fields=[gt_boxes: Boxes(tensor([[  3.6973,  25.3147, 939.0000, 704.0000],
        [ 50.9261, 177.0707,  87.4297, 230.3987],
        [ 86.6374, 220.0147, 137.6222, 252.9413],
        [ 61.4458, 222.3320, 104.6105, 242.7773]], device='cuda:0')), gt_classes: tensor([59, 41, 65, 65], device='cuda:0'), gt_masks: PolygonMasks(num_instances=4)])]
(Pdb) len(gt_instances)
2

'''


'''
(Pdb) batched_inputs[0]['image'].size()
torch.Size([3, 768, 1229])
(Pdb) batched_inputs[1]['image'].size()
torch.Size([3, 704, 939])

(Pdb) batched_inputs[0].keys()
dict_keys(['file_name', 'height', 'width', 'image_id', 'image', 'instances'])

'''
'''
(Pdb) features['p3'].size()
torch.Size([2, 256, 96, 156])
(Pdb) features['p4'].size()
torch.Size([2, 256, 48, 78])
(Pdb) features['p5'].size()
torch.Size([2, 256, 24, 39])
(Pdb) features['p6'].size()
torch.Size([2, 256, 12, 20])
(Pdb) features['p7'].size()
torch.Size([2, 256, 6, 10])
(Pdb) 

'''

'''
MaskBranch(
  (refine): ModuleList(
    (0): Sequential(
      (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
  )
  (tower): Sequential(
    (0): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (1): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (2): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (3): Sequential(
      (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
    )
    (4): Conv2d(128, 8, kernel_size=(1, 1), stride=(1, 1))
  )
)

'''

2. AdelaiDet/adet/modeling/condinst/mask_branch.py

from typing import Dict
import math

import torch
from torch import nn

from fvcore.nn import sigmoid_focal_loss_jit
from detectron2.layers import ShapeSpec

from adet.layers import conv_with_kaiming_uniform
from adet.utils.comm import aligned_bilinear
import pdb

INF = 100000000


def build_mask_branch(cfg, input_shape):
    """Factory: build the CondInst mask branch from ``cfg`` and the backbone output shapes."""
    branch = MaskBranch(cfg, input_shape)
    return branch


class MaskBranch(nn.Module):
    """
    CondInst mask branch: fuses several FPN levels into one mask feature map.

    Every level in ``MASK_BRANCH.IN_FEATURES`` goes through its own ``refine``
    conv block. The outputs of later levels are upsampled with
    ``aligned_bilinear`` to the resolution of the first level and summed. The
    sum then passes through ``tower``: ``NUM_CONVS`` conv blocks followed by a
    1x1 conv to ``OUT_CHANNELS``. When ``SEMANTIC_LOSS_ON`` is set, an
    auxiliary semantic-segmentation focal loss is also computed in training.
    """

    def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]):
        super().__init__()
        self.in_features = cfg.MODEL.CONDINST.MASK_BRANCH.IN_FEATURES
        self.sem_loss_on = cfg.MODEL.CONDINST.MASK_BRANCH.SEMANTIC_LOSS_ON
        self.num_outputs = cfg.MODEL.CONDINST.MASK_BRANCH.OUT_CHANNELS
        norm = cfg.MODEL.CONDINST.MASK_BRANCH.NORM
        num_convs = cfg.MODEL.CONDINST.MASK_BRANCH.NUM_CONVS
        channels = cfg.MODEL.CONDINST.MASK_BRANCH.CHANNELS
        # The fused map keeps the resolution (stride) of the first input level.
        self.out_stride = input_shape[self.in_features[0]].stride

        feature_channels = {k: v.channels for k, v in input_shape.items()}

        conv_block = conv_with_kaiming_uniform(norm, activation=True)

        # refine module: one 3x3 conv block per input level -> `channels`.
        self.refine = nn.ModuleList()
        for in_feature in self.in_features:
            self.refine.append(conv_block(
                feature_channels[in_feature],
                channels, 3, 1
            ))

        # tower module: num_convs 3x3 conv blocks + a final 1x1 projection.
        tower = []
        for _ in range(num_convs):
            tower.append(conv_block(
                channels, channels, 3, 1
            ))
        # At least one output channel is built even when OUT_CHANNELS == 0;
        # forward() slices it away in that case.
        tower.append(nn.Conv2d(
            channels, max(self.num_outputs, 1), 1
        ))
        self.add_module('tower', nn.Sequential(*tower))

        if self.sem_loss_on:
            num_classes = cfg.MODEL.FCOS.NUM_CLASSES
            self.focal_loss_alpha = cfg.MODEL.FCOS.LOSS_ALPHA
            self.focal_loss_gamma = cfg.MODEL.FCOS.LOSS_GAMMA

            in_channels = feature_channels[self.in_features[0]]
            self.seg_head = nn.Sequential(
                conv_block(in_channels, channels, kernel_size=3, stride=1),
                conv_block(channels, channels, kernel_size=3, stride=1)
            )

            self.logits = nn.Conv2d(channels, num_classes, kernel_size=1, stride=1)

            # Initialize the bias so the initial sigmoid output equals prior_prob.
            prior_prob = cfg.MODEL.FCOS.PRIOR_PROB
            bias_value = -math.log((1 - prior_prob) / prior_prob)
            torch.nn.init.constant_(self.logits.bias, bias_value)

    def forward(self, features, gt_instances=None):
        """
        Args:
            features (dict[str, Tensor]): backbone outputs keyed by level name.
            gt_instances (list[Instances] | None): per-image ground truth. It
                is read only when training with the semantic loss enabled,
                which needs ``gt_bitmasks_full`` and ``gt_classes``.

        Returns:
            tuple: ``(mask_feats, losses)``. ``losses`` is empty unless the
            semantic loss is active, in which case it holds ``"loss_sem"``.
        """
        for i, f in enumerate(self.in_features):
            if i == 0:
                # The first level fixes the target resolution.
                x = self.refine[i](features[f])
            else:
                x_p = self.refine[i](features[f])

                target_h, target_w = x.size()[2:]
                h, w = x_p.size()[2:]
                assert target_h % h == 0
                assert target_w % w == 0
                factor_h, factor_w = target_h // h, target_w // w
                assert factor_h == factor_w
                # Upsample the coarser level to the first level's size and add.
                x_p = aligned_bilinear(x_p, factor_h)
                x = x + x_p

        mask_feats = self.tower(x)

        if self.num_outputs == 0:
            mask_feats = mask_feats[:, :self.num_outputs]

        losses = {}

        # auxiliary "thing" semantic loss (disabled unless SEMANTIC_LOSS_ON)
        if self.training and self.sem_loss_on:
            logits_pred = self.logits(self.seg_head(
                features[self.in_features[0]]
            ))

            # compute semantic targets: each pixel gets the (1-based) class of
            # the smallest-area GT mask covering it, or 0 if none covers it.
            semantic_targets = []
            for per_im_gt in gt_instances:
                h, w = per_im_gt.gt_bitmasks_full.size()[-2:]
                areas = per_im_gt.gt_bitmasks_full.sum(dim=-1).sum(dim=-1)
                areas = areas[:, None, None].repeat(1, h, w)
                areas[per_im_gt.gt_bitmasks_full == 0] = INF
                areas = areas.permute(1, 2, 0).reshape(h * w, -1)
                min_areas, inds = areas.min(dim=1)
                per_im_sematic_targets = per_im_gt.gt_classes[inds] + 1
                per_im_sematic_targets[min_areas == INF] = 0
                per_im_sematic_targets = per_im_sematic_targets.reshape(h, w)
                semantic_targets.append(per_im_sematic_targets)

            semantic_targets = torch.stack(semantic_targets, dim=0)

            # resize target to reduce memory
            semantic_targets = semantic_targets[
                :, None, self.out_stride // 2::self.out_stride,
                self.out_stride // 2::self.out_stride
            ]

            # prepare one-hot targets (class k <-> target value k + 1)
            num_classes = logits_pred.size(1)
            class_range = torch.arange(
                num_classes, dtype=logits_pred.dtype,
                device=logits_pred.device
            )[:, None, None]
            class_range = class_range + 1
            one_hot = (semantic_targets == class_range).float()
            num_pos = (one_hot > 0).sum().float().clamp(min=1.0)

            loss_sem = sigmoid_focal_loss_jit(
                logits_pred, one_hot,
                alpha=self.focal_loss_alpha,
                gamma=self.focal_loss_gamma,
                reduction="sum",
            ) / num_pos
            losses['loss_sem'] = loss_sem

        return mask_feats, losses
'''
{
    'p3': ShapeSpec(channels=256, height=None, width=None, stride=8),
    'p4': ShapeSpec(channels=256, height=None, width=None, stride=16), 
    'p5': ShapeSpec(channels=256, height=None, width=None, stride=32), 
    'p6': ShapeSpec(channels=256, height=None, width=None, stride=64), 
    'p7': ShapeSpec(channels=256, height=None, width=None, stride=128)
    }

'''

3. AdelaiDet/adet/modeling/condinst/dynamic_mask_head.py

import torch
from torch.nn import functional as F
from torch import nn

from adet.utils.comm import compute_locations, aligned_bilinear
import pdb

def dice_coefficient(x, target):
    """
    Per-instance Dice loss: ``1 - 2 * sum(x * t) / (sum(x^2) + sum(t^2) + eps)``.

    Args:
        x: predicted mask scores. Dim 0 indexes instances and all remaining
            dims are flattened (e.g. ``[N, 1, H, W]``).
        target: ground-truth masks with the same per-instance element count
            as ``x``.

    Returns:
        Tensor of shape ``[N]`` holding one loss value per instance.
    """
    eps = 1e-5
    n_inst = x.size(0)
    x = x.reshape(n_inst, -1)
    target = target.reshape(n_inst, -1)
    intersection = (x * target).sum(dim=1)
    # eps keeps the ratio finite when both prediction and target are empty.
    union = (x ** 2.0).sum(dim=1) + (target ** 2.0).sum(dim=1) + eps
    loss = 1. - (2 * intersection / union)
    return loss


def parse_dynamic_params(params, channels, weight_nums, bias_nums):
    """
    Split flat per-instance parameters into grouped-conv weights and biases.

    Args:
        params: ``[num_insts, sum(weight_nums) + sum(bias_nums)]`` tensor. All
            layer weights come first, followed by all layer biases.
        channels: output channels of every layer except the last one, which
            has a single output channel.
        weight_nums: number of weight values per layer (list or tuple).
        bias_nums: number of bias values per layer (list or tuple).

    Returns:
        ``(weights, biases)``, one entry per layer. Weights are shaped
        ``[num_insts * out_ch, in_ch, 1, 1]`` and biases ``[num_insts * out_ch]``,
        ready for ``F.conv2d(..., groups=num_insts)``.
    """
    assert params.dim() == 2
    assert len(weight_nums) == len(bias_nums)
    assert params.size(1) == sum(weight_nums) + sum(bias_nums)

    num_insts = params.size(0)
    num_layers = len(weight_nums)

    # list(...) so tuples are accepted as well as lists.
    params_splits = list(torch.split_with_sizes(
        params, list(weight_nums) + list(bias_nums), dim=1
    ))

    weight_splits = params_splits[:num_layers]
    bias_splits = params_splits[num_layers:]

    for l in range(num_layers):
        # Hidden layers output `channels` maps; the last outputs one mask logit.
        out_ch = channels if l < num_layers - 1 else 1
        # out_channels x in_channels x 1 x 1, instances stacked along dim 0
        weight_splits[l] = weight_splits[l].reshape(num_insts * out_ch, -1, 1, 1)
        bias_splits[l] = bias_splits[l].reshape(num_insts * out_ch)

    return weight_splits, bias_splits


def build_dynamic_mask_head(cfg):
    """Factory: build the CondInst dynamic mask head described by ``cfg``."""
    head = DynamicMaskHead(cfg)
    return head


class DynamicMaskHead(nn.Module):
    """
    CondInst dynamic mask head.

    The head has no learnable weights of its own. Each instance carries
    ``num_gen_params`` numbers (``instances.mask_head_params``), which are
    unpacked into the weights and biases of ``NUM_LAYERS`` 1x1 convolutions.
    Those convolutions run on the shared mask features, optionally
    concatenated with instance-relative coordinates.
    """

    def __init__(self, cfg):
        # Sets up the per-layer parameter counts and num_gen_params.
        super(DynamicMaskHead, self).__init__()
        self.num_layers = cfg.MODEL.CONDINST.MASK_HEAD.NUM_LAYERS
        self.channels = cfg.MODEL.CONDINST.MASK_HEAD.CHANNELS
        self.in_channels = cfg.MODEL.CONDINST.MASK_BRANCH.OUT_CHANNELS
        self.mask_out_stride = cfg.MODEL.CONDINST.MASK_OUT_STRIDE
        self.disable_rel_coords = cfg.MODEL.CONDINST.MASK_HEAD.DISABLE_REL_COORDS

        # One size-of-interest value per FPN level (indexed by
        # instances.fpn_levels); the appended entry doubles the last one.
        # list(...) so a tuple config value works too.
        soi = list(cfg.MODEL.FCOS.SIZES_OF_INTEREST)
        self.register_buffer("sizes_of_interest", torch.tensor(soi + [soi[-1] * 2]))

        # Number of generated weights / biases for each conv layer.
        weight_nums, bias_nums = [], []
        for layer_idx in range(self.num_layers):
            if layer_idx == 0:
                if not self.disable_rel_coords:
                    # +2 input channels for the (x, y) relative coordinates.
                    weight_nums.append((self.in_channels + 2) * self.channels)
                else:
                    weight_nums.append(self.in_channels * self.channels)
                bias_nums.append(self.channels)
            elif layer_idx == self.num_layers - 1:
                # Last layer: a single output channel (the mask logit).
                weight_nums.append(self.channels * 1)
                bias_nums.append(1)
            else:
                weight_nums.append(self.channels * self.channels)
                bias_nums.append(self.channels)

        self.weight_nums = weight_nums
        self.bias_nums = bias_nums
        self.num_gen_params = sum(weight_nums) + sum(bias_nums)

    def mask_heads_forward(self, features, weights, biases, num_insts):
        '''
        Apply the dynamic 1x1-conv stack to all instances at once.

        :param features: 4-D tensor. As built by mask_heads_forward_with_coords
            it is ``[1, num_insts * C_in, H, W]``: each instance's channels are
            stacked along dim 1 so that a grouped conv (``groups=num_insts``)
            applies each instance's own filters.
        :param weights: per-layer conv weights ``[w0, w1, ...]``.
        :param biases: per-layer conv biases ``[b0, b1, ...]``.
        :param num_insts: number of instances (= number of conv groups).
        :return: output of the last layer (no activation after it).
        '''
        assert features.dim() == 4
        n_layers = len(weights)
        x = features
        for i, (w, b) in enumerate(zip(weights, biases)):
            x = F.conv2d(
                x, w, bias=b,
                stride=1, padding=0,
                groups=num_insts
            )
            # ReLU between layers only; the last layer emits raw logits.
            if i < n_layers - 1:
                x = F.relu(x)
        return x

    def mask_heads_forward_with_coords(
            self, mask_feats, mask_feat_stride, instances
    ):
        """
        Predict one mask per instance from the shared mask features.

        Args:
            mask_feats: ``[N, C, H, W]`` mask-branch output for N images.
            mask_feat_stride: stride of ``mask_feats`` w.r.t. the input. It
                must be >= and a multiple of ``self.mask_out_stride``.
            instances: Instances with ``im_inds`` and ``mask_head_params``.
                Unless relative coords are disabled, they also need
                ``locations`` and ``fpn_levels``.

        Returns:
            ``[num_instances, 1, H', W']`` mask probabilities (after sigmoid),
            upsampled by ``mask_feat_stride // mask_out_stride``.
        """
        locations = compute_locations(
            mask_feats.size(2), mask_feats.size(3),
            stride=mask_feat_stride, device=mask_feats.device
        )
        n_inst = len(instances)

        # Image index of each instance, used to pick its image's mask_feats.
        im_inds = instances.im_inds
        mask_head_params = instances.mask_head_params

        _, _, H, W = mask_feats.size()

        if not self.disable_rel_coords:
            instance_locations = instances.locations
            # Offset from each instance's location to every mask-feature
            # location: [n_inst, 1, 2] - [1, H*W, 2] -> [n_inst, H*W, 2].
            relative_coords = instance_locations.reshape(-1, 1, 2) - locations.reshape(1, -1, 2)
            relative_coords = relative_coords.permute(0, 2, 1).float()  # [n_inst, 2, H*W]
            # Scale offsets by the size of interest of the FPN level that
            # produced each instance (presumably so offsets have a comparable
            # magnitude across levels).
            soi = self.sizes_of_interest.float()[instances.fpn_levels]
            relative_coords = relative_coords / soi.reshape(-1, 1, 1)
            relative_coords = relative_coords.to(dtype=mask_feats.dtype)

            mask_head_inputs = torch.cat([
                relative_coords, mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)
            ], dim=1)
        else:
            mask_head_inputs = mask_feats[im_inds].reshape(n_inst, self.in_channels, H * W)

        # Stack all instances along the channel dim for the grouped conv.
        mask_head_inputs = mask_head_inputs.reshape(1, -1, H, W)

        weights, biases = parse_dynamic_params(
            mask_head_params, self.channels,
            self.weight_nums, self.bias_nums
        )

        mask_logits = self.mask_heads_forward(mask_head_inputs, weights, biases, n_inst)
        mask_logits = mask_logits.reshape(-1, 1, H, W)

        assert mask_feat_stride >= self.mask_out_stride
        assert mask_feat_stride % self.mask_out_stride == 0
        mask_logits = aligned_bilinear(mask_logits, int(mask_feat_stride // self.mask_out_stride))

        return mask_logits.sigmoid()

    def __call__(self, mask_feats, mask_feat_stride, pred_instances, gt_instances=None):
        """
        Training: return the mean Dice loss between the predicted masks and
        their assigned GT masks. ``pred_instances.gt_inds`` indexes the
        concatenation of all images' ``gt_bitmasks``.
        Inference: attach ``pred_global_masks`` to ``pred_instances`` (when
        non-empty) and return it.

        NOTE: this overrides ``__call__`` rather than ``forward``, so
        ``nn.Module`` hooks are not run (kept as in the original code).
        """
        if self.training:
            gt_inds = pred_instances.gt_inds
            gt_bitmasks = torch.cat([per_im.gt_bitmasks for per_im in gt_instances])
            # Pick, for every prediction, the GT mask it was assigned to.
            gt_bitmasks = gt_bitmasks[gt_inds].unsqueeze(dim=1).to(dtype=mask_feats.dtype)

            if len(pred_instances) == 0:
                # Zero loss that stays connected to mask_feats and the params.
                loss_mask = mask_feats.sum() * 0 + pred_instances.mask_head_params.sum() * 0
            else:
                mask_scores = self.mask_heads_forward_with_coords(
                    mask_feats, mask_feat_stride, pred_instances
                )
                mask_losses = dice_coefficient(mask_scores, gt_bitmasks)
                loss_mask = mask_losses.mean()
            return loss_mask.float()

        if len(pred_instances) > 0:
            mask_scores = self.mask_heads_forward_with_coords(
                mask_feats, mask_feat_stride, pred_instances
            )
            pred_instances.pred_global_masks = mask_scores.float()

        return pred_instances

'''
1. gt_bitmasks
gt_bitmasks = torch.cat([per_im.gt_bitmasks for per_im in gt_instances]) # 循环batchsize次
    (Pdb) gt_instances[0].gt_bitmasks.size()
    torch.Size([15, 200, 304])
    (Pdb) gt_instances[1].gt_bitmasks.size()
    torch.Size([3, 200, 304])

2. gt_bitmasks
 gt_bitmasks = gt_bitmasks[gt_inds].unsqueeze(dim=1).to(dtype=mask_feats.dtype)
 [160, 1, 200, 304]


(Pdb) pred_instances.gt_inds

tensor([ 0,  0,  0,  0,  0,  0,  0,  0,  0,  6,  6,  6,  5,  5,  5,  6,  6,  6,
         5,  5,  5,  6,  6,  6,  9,  9,  9,  8,  8,  8, 12, 12,  5, 10, 10, 10,
        11, 11,  4,  4,  4,  9,  9,  9,  8,  8,  8, 12, 12, 10, 10, 10, 13, 13,
        11, 11,  4,  4,  4,  9,  9,  9,  8,  8,  8, 12, 12, 10, 10, 10, 13, 13,
        11, 11,  4,  4,  4, 17, 17, 17, 17, 17, 17, 17, 17, 17,  1,  1,  1,  3,
         3,  1,  1,  1,  3,  3,  2,  2,  2,  1,  1,  1,  3,  3,  2,  2,  2,  4,
         4,  4,  2,  2,  2,  4,  4,  4, 14, 14, 14, 14, 14, 14, 14, 14, 14, 17,
        17, 17, 15, 15, 15,  2,  2,  7,  7,  7,  7,  7,  7,  7,  7,  7, 14, 14,
        15, 15, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16],
       device='cuda:0')

(Pdb) soi
tensor([  64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,   64.,
          64.,   64.,   64.,   64.,   64.,   64.,  128.,  128.,  128.,  128.,
         128.,  128.,  128.,  128.,  128.,  128.,  128.,  128.,  128.,  128.,
         128.,  128.,  128.,  128.,  128.,  128.,  128.,  128.,  128.,  128.,
         128.,  128.,  128.,  128.,  128.,  128.,  128.,  128.,  128.,  128.,
         128.,  128.,  128.,  128.,  128.,  128.,  128.,  128.,  128.,  128.,
         128.,  256.,  256.,  256.,  256.,  256.,  256.,  256.,  256.,  256.,
         256.,  256.,  256.,  256.,  256.,  256.,  256.,  256.,  256.,  256.,
         512.,  512.,  512., 1024., 1024., 1024., 1024., 1024., 1024., 1024.],
       device='cuda:0')
(Pdb) soi.size()
torch.Size([160])


(Pdb) instances.fpn_levels
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4], device='cuda:0')


(Pdb) mask_head_inputs.size()
torch.Size([1, 1600, 100, 152])
(Pdb)  self.channels
8
(Pdb)  self.bias_nums
[8, 8, 1]
(Pdb) self.weight_nums
[80, 64, 8]
(Pdb) mask_head_params.size()
torch.Size([160, 169])


parse_dynamic_params()方法

(Pdb) len(weight_splits)
3
(Pdb) weight_splits[0].size()
torch.Size([1280, 10, 1, 1])
(Pdb) weight_splits[1].size()
torch.Size([1280, 8, 1, 1])
(Pdb) weight_splits[2].size()
torch.Size([160, 8, 1, 1])



(Pdb) len(bias_splits)
3
(Pdb) bias_splits[0].size()
torch.Size([1280])
(Pdb) bias_splits[1].size()
torch.Size([1280])
(Pdb) bias_splits[2].size()
torch.Size([160])

'''

4. AdelaiDet/adet/modeling/fcos/fcos_outputs.py 中 CondInst 的 top_feats 结构

 def losses(self, logits_pred, reg_pred, ctrness_pred, locations, gt_instances, top_feats=None):
        """
        Return the losses from a set of FCOS predictions and their associated ground-truth.

        All per-level training targets and predictions are flattened and
        concatenated into a single ``Instances`` object with one row per
        location, which is then passed to ``self.fcos_losses``. The final row
        ordering is L, N, H, W from slowest to fastest axis.

        Args:
            logits_pred: list of per-level class logits, each (N, C, Hi, Wi)
                with C == self.num_classes.
            reg_pred: list of per-level box regressions, each (N, B, Hi, Wi)
                with B == 4.
            ctrness_pred: list of per-level centerness predictions, each
                (N, 1, Hi, Wi).
            locations: per-level location coordinates; forwarded to
                ``self._get_ground_truth`` together with ``gt_instances``.
            gt_instances: per-image ground-truth instances.
            top_feats: optional list of per-level extra head outputs, each
                (N, D, Hi, Wi) -- e.g. D=784 in BlendMask, or D=169 dynamic
                mask-head parameters in CondInst. ``None`` or an empty list
                means there are no top features.

        Returns:
            dict[loss name -> loss value]: A dict mapping from loss name to loss value.
        """
        training_targets = self._get_ground_truth(locations, gt_instances)

        # Collect all logits and regression predictions over feature maps
        # and images to arrive at the same shape as the labels and targets.
        # The final ordering is L, N, H, W from slowest to fastest axis.
        instances = Instances((0, 0))
        instances.labels = cat([
            # Reshape: (N, 1, Hi, Wi) -> (N*Hi*Wi,)
            x.reshape(-1) for x in training_targets["labels"]
        ], dim=0)
        instances.gt_inds = cat([
            # Reshape: (N, 1, Hi, Wi) -> (N*Hi*Wi,)
            x.reshape(-1) for x in training_targets["target_inds"]
        ], dim=0)
        instances.im_inds = cat([
            # NOTE(review): presumably the batch-image index of each location;
            # confirm against _get_ground_truth.
            x.reshape(-1) for x in training_targets["im_inds"]
        ], dim=0)
        instances.reg_targets = cat([
            # Reshape: (N, Hi, Wi, 4) -> (N*Hi*Wi, 4)
            x.reshape(-1, 4) for x in training_targets["reg_targets"]
        ], dim=0,)
        instances.locations = cat([
            x.reshape(-1, 2) for x in training_targets["locations"]
        ], dim=0)
        instances.fpn_levels = cat([
            x.reshape(-1) for x in training_targets["fpn_levels"]
        ], dim=0)

        instances.logits_pred = cat([
            # Reshape: (N, C, Hi, Wi) -> (N, Hi, Wi, C) -> (N*Hi*Wi, C)
            x.permute(0, 2, 3, 1).reshape(-1, self.num_classes) for x in logits_pred
        ], dim=0,)
        instances.reg_pred = cat([
            # Reshape: (N, B, Hi, Wi) -> (N, Hi, Wi, B) -> (N*Hi*Wi, B)
            x.permute(0, 2, 3, 1).reshape(-1, 4) for x in reg_pred
        ], dim=0,)
        instances.ctrness_pred = cat([
            # Reshape: (N, 1, Hi, Wi) -> (N*Hi*Wi,)
            x.permute(0, 2, 3, 1).reshape(-1) for x in ctrness_pred
        ], dim=0,)

        # top_feats is produced by both BlendMask and CondInst. The default is
        # None, so guard against it before calling len().
        if top_feats is not None and len(top_feats) > 0:
            instances.top_feats = cat([
                # Reshape: (N, D, Hi, Wi) -> (N, Hi, Wi, D) -> (N*Hi*Wi, D)
                x.permute(0, 2, 3, 1).reshape(-1, x.size(1)) for x in top_feats
            ], dim=0,)

        # Per-level top_feats shapes observed while debugging (N = 2 images):
        #   BlendMask: [2, 784, 96, 148], [2, 784, 48, 74], [2, 784, 24, 37],
        #              [2, 784, 12, 19],  [2, 784, 6, 10]
        #   CondInst:  [2, 169, 100, 152], [2, 169, 50, 76], [2, 169, 25, 38],
        #              [2, 169, 13, 19],  [2, 169, 7, 10]
        # After flattening, instances.top_feats is (sum over levels of N*Hi*Wi, D):
        #   BlendMask: [37872, 784]; CondInst: [40534, 169].
        # Each row is one spatial location (the H*W map flattened to 1-D, for
        # every image in the batch); each column is one channel. According to
        # the original notes, fcos_losses filters these rows further, leaving
        # one row per selected instance.
        return self.fcos_losses(instances)

猜你喜欢

转载自blog.csdn.net/weixin_43823854/article/details/110674757
今日推荐