The previous chapter covered how to build the dataset. In this one we use mmcls to implement multi-label classification.
Config setup
mmcls defines the entire network structure through a config file. Below I use ResNet-18, and because the data has 5 attributes, the output is num_classes=5. Note that the head must be head=dict(type='MultiLabelLinearClsHead'). In multi-label classification, each output score goes through its own sigmoid, so every attribute gets an independent probability in [0, 1]. If softmax were used instead, the attributes would become mutually exclusive: softmax normalizes pred along dim=1 so that the scores sum to 1, and raising one attribute's probability lowers all the others. For multi-label problems the loss should therefore be F.binary_cross_entropy_with_logits, which applies the sigmoid internally, so no separate sigmoid is needed before the loss.
```python
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        # type='LinearClsHead',  # single-label (softmax) head, not used here
        type='MultiLabelLinearClsHead',
        num_classes=5,    # one output per attribute
        in_channels=512,  # channels of the last ResNet-18 stage
        # loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        # topk=(1, 5),
    ))
```
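To make the softmax versus sigmoid point concrete, here is a minimal pure-Python sketch (no mmcls or PyTorch involved; the function names and the logits are illustrative assumptions). It compares softmax and sigmoid on the same five scores and computes the mean binary cross-entropy directly from logits, which is the quantity F.binary_cross_entropy_with_logits returns with its default mean reduction.

```python
import math


def softmax(logits):
    # Shift by the max for numerical stability; the outputs always sum to 1.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x):
    # Independent per-attribute probability in (0, 1).
    return 1.0 / (1.0 + math.exp(-x))


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed from raw logits.

    Each term is a numerically stable rewrite of
    -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))].
    """
    losses = [
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for x, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)


# Hypothetical scores for 5 attributes; attributes 0 and 3 are present.
logits = [2.0, -1.0, -0.5, 3.0, -2.0]
targets = [1, 0, 0, 1, 0]

print([round(p, 3) for p in softmax(logits)])  # sums to 1: attributes compete
print([round(sigmoid(x), 3) for x in logits])  # each score is independent
print(round(bce_with_logits(logits, targets), 4))
```

With a 0.5 threshold, the sigmoid scores mark both attribute 0 and attribute 3 as present, while under softmax only attribute 3 stays above 0.5, because the probability mass has to be shared across all five attributes.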
To read the data and convert the labels into a format the custom dataset can compute the loss on, we need to override load_annotations(self). To avoid adding a new dataset class, a flag self.multi_label distinguishes multi-label from multi-class data. In the txt annotation file, each line holds an image path followed by its label, given as class indices. For example, with 5 attribute categories a line's label might be 1,3, whereas BCE expects a multi-hot vector such as [0, 1, 0, 1, 0], so the label has to be converted.
```python
import mmcv
import numpy as np


# Method of the custom dataset class; self.multi_label is the flag described above.
def load_annotations(self):
    """Load image paths and gt_labels."""
    if self.ann_file is None:
        samples = self._find_samples()
    elif isinstance(self.ann_file, str):
        lines = mmcv.list_from_file(
            self.ann_file, file_client_args=self.file_client_args)
        # Each line is "<filename> <label>"; split on the last space only.
        samples = [x.strip().rsplit(' ', 1) for x in lines]
    else:
        raise TypeError('ann_file must be a str or None')

    data_infos = []
    for filename, gt_label in samples:
        info = {'img_prefix': self.data_prefix}
        info['img_info'] = {'filename': filename}
        if not self.multi_label:
            # Multi-class: a single integer class index.
            info['gt_label'] = np.array(gt_label, dtype=np.int64)
        else:
            # Multi-label: turn "1,3" into a multi-hot vector, e.g.
            # [0, 1, 0, 1, 0] for 5 classes. A single index such as "3"
            # is handled by the same code, so no separate branch is needed.
            temp_label = np.zeros(len(self.CLASSES), dtype=np.float32)
            indices = np.array(str(gt_label).split(','), dtype=np.int64)
            temp_label[indices] = 1
            info['gt_label'] = temp_label
        data_infos.append(info)
    return data_infos
```
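The label conversion can be tried outside mmcls. The sketch below is a hypothetical stand-alone helper (parse_annotation_line is not part of mmcls) that mirrors the multi-label branch of load_annotations, assuming 0-based class indices and the "path label" line format described above.

```python
def parse_annotation_line(line, num_classes):
    """Split 'path idx[,idx...]' and return (path, multi-hot label list)."""
    # Split on the last space, as load_annotations does, so the label is
    # always the final field even if the path contains spaces.
    filename, label_str = line.strip().rsplit(' ', 1)
    multi_hot = [0.0] * num_classes
    for idx in label_str.split(','):
        multi_hot[int(idx)] = 1.0  # an index >= num_classes raises IndexError
    return filename, multi_hot


print(parse_annotation_line('images/0001.jpg 1,3', num_classes=5))
# ('images/0001.jpg', [0.0, 1.0, 0.0, 1.0, 0.0])
print(parse_annotation_line('images/0002.jpg 4', num_classes=5))
# ('images/0002.jpg', [0.0, 0.0, 0.0, 0.0, 1.0])
```

Because a single index is just a one-element list after split(','), one loop covers both single-attribute and multi-attribute lines.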
Next, multi-label training can be carried out.