mmcls多标签分类实战(二):resnet多标签分类

上一章讲了如何制作数据集,接下来我们使用mmcls来实现多标签分类。

Config配置
mmcls是通过config来配置整个网络结构的。如下,我使用的是resnet18,因为数据中有5个属性,所以输出的num_classes=5。需要注意的是,head要选用head=dict(type=‘MultiLabelLinearClsHead’)。这是因为多标签分类,在进入loss前,应该用sigmoid激活,将pred的值归一化。如果使用softmax,会出现属性互斥的现象(因为pred在dim=1上,sum=1)。对于Multi-label问题,应该使用F.binary_cross_entropy_with_logits损失。

model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        # type='LinearClsHead',
        type='MultiLabelLinearClsHead',
        num_classes=5,
        in_channels=512,
        # loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        # topk=(1, 5),
    ))

自定义dataset
为了读取数据,并将label转变为loss可以计算的格式,我们需要重新定 def load_annotations(self):为了不增加类,定义了self.multi_label的flag来分离Multi-label与Multi-class。我们在txt中的label是一个num,比如你有5个属性类别,label可能是1,3,而BCE中label需要的格式是[1,0,1],因此我们需要转化一下。

def load_annotations(self):
        """Load image paths and gt_labels."""
        if self.ann_file is None:
            samples = self._find_samples()
        elif isinstance(self.ann_file, str):
            lines = mmcv.list_from_file(
                self.ann_file, file_client_args=self.file_client_args)
            samples = [x.strip().rsplit(' ', 1) for x in lines]
        else:
            raise TypeError('ann_file must be a str or None')

        data_infos = []
        for filename, gt_label in samples:
            info = {
    
    'img_prefix': self.data_prefix}
            info['img_info'] = {
    
    'filename': filename}
            temp_label = np.zeros(len(self.CLASSES))
            
            if not self.multi_label:
                info['gt_label'] = np.array(gt_label, dtype=np.int64)
            else:
                ##multi-label classify
                if len(gt_label) == 1:
                    temp_label[np.array(gt_label, dtype=np.int64)] = 1
                    info['gt_label'] = temp_label
                else:
                    for i in range(np.array(gt_label.split(','), dtype=np.int64).shape[0]):
                        temp_label[np.array(gt_label.split(','), dtype=np.int64)[i]] = 1
                    info['gt_label'] = temp_label
            
            data_infos.append(info)
        return data_infos

接下来就可以进行多标签的训练了。

猜你喜欢

转载自blog.csdn.net/litt1e/article/details/125316552