The difference between multi-task and multi-label
Define multi-task network structure under MXnet
Define Multi-task evaluation metric
Single-task accuracy function:
The official multi_accuracy function for image multi-label classification:
write in front
This series of blog posts records the author's process of getting started with MXNet. Before encountering MXNet, the author mainly used Keras and a little TensorFlow, so some deep-learning project experience was already in place. The main reference material is the official MXNet tutorial, supplemented by several valuable blog posts.
The structure of the blog is: first list the author’s expected goals for this stage, and the notes taken during the completion of each goal (only write down the ones that I think are important), and then attach my own questions during the learning process (solved & unsolved, wild questions , welcome to discuss).
Objectives of this stage
Task | priority | Expected to take time | finished condition | Encounter problems | Replenish |
Define Multi-task data format | P0 | 2hour | |||
Define Multi-task network | P1 | 0.5hour | |||
Define Multi-task evaluation metric | P2 | 1.5hour | |||
Network training and evaluation |
specific notes
The difference between multi-task and multi-label
-
Multi-task is more complicated than multi-label, and the intermediate process of the network can have branches
-
Multi-label is a special case of multi-task: when every task is a binary classification, the problem is multi-label, whereas in general multi-task learning each task may itself be a multi-class classification.
Define multi-task network structure under MXnet
- Code:
- Illustration:
Define Multi-task evaluation metric
There is a lot of metric information about multi-task on the Internet, but they are basically multi_accuracy. Here we have compiled the single (multi-) task versions of accuracy / cross-entropy / precision / recall.
-
Single-task accuracy function:
import numpy as np

import mxnet as mx
class Accuracy(mx.metric.EvalMetric):
    """Accuracy metric for a single-task (single-output) network.

    Only the first output/label pair of the batch is scored; the running
    counters inherited from ``EvalMetric`` accumulate across batches.
    """

    def __init__(self, num=None):
        super(Accuracy, self).__init__('accuracy', num)

    def update(self, labels, preds):
        """Accumulate correct-prediction counts for one batch."""
        predicted = mx.nd.argmax_channel(preds[0]).asnumpy().astype('int32')
        truth = labels[0].asnumpy().astype('int32')
        mx.metric.check_label_shapes(truth, predicted)
        # Count element-wise matches over the flattened arrays.
        self.sum_metric += (predicted.flat == truth.flat).sum()
        self.num_inst += len(predicted.flat)
-
The official multi_accuracy function for image multi-label classification:
class Multi_Accuracy(mx.metric.EvalMetric):
    """Per-task accuracy for multi-task (multi-output) networks.

    When ``num`` is None the metric behaves like plain single-task
    accuracy; when ``num`` is an int, one (sum, count) pair is kept per
    task and ``get`` reports one accuracy value per task.
    """

    def __init__(self, num=None):
        # self.num must be set before calling the base constructor,
        # because EvalMetric.__init__ invokes reset(), which reads it.
        self.num = num
        super(Multi_Accuracy, self).__init__('multi-accuracy')

    def reset(self):
        """Reset state: scalar counters for single task, lists per task."""
        self.num_inst = 0 if self.num is None else [0] * self.num
        self.sum_metric = 0.0 if self.num is None else [0.0] * self.num

    def update(self, labels, preds):
        """Accumulate per-task correct-prediction counts for one batch."""
        mx.metric.check_label_shapes(labels, preds)
        if self.num is not None:
            assert len(labels) == self.num
        for i in range(len(labels)):
            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')
            label = labels[i].asnumpy().astype('int32')
            mx.metric.check_label_shapes(label, pred_label)
            if self.num is None:
                self.sum_metric += (pred_label.flat == label.flat).sum()
                self.num_inst += len(pred_label.flat)
            else:
                self.sum_metric[i] += (pred_label.flat == label.flat).sum()
                self.num_inst[i] += len(pred_label.flat)

    def get(self):
        """Gets the current evaluation result.

        Returns
        -------
        names : list of str
            Name of the metric for each task.
        values : list of float
            Accuracy per task; NaN for tasks with no instances yet.
        """
        if self.num is None:
            return super(Multi_Accuracy, self).get()
        # Build concrete lists instead of a lazy zip object: a zip is
        # exhausted after one traversal in Python 3 and breaks callers
        # that read the result more than once.
        names = ['%s-task%d' % (self.name, i) for i in range(self.num)]
        values = [float('nan') if self.num_inst[i] == 0
                  else self.sum_metric[i] / self.num_inst[i]
                  for i in range(self.num)]
        return names, values

    def get_name_value(self):
        """Returns zipped name and value pairs.

        Returns
        -------
        list of tuples
            A (name, value) tuple list, one entry per task.
        """
        if self.num is None:
            return super(Multi_Accuracy, self).get_name_value()
        names, values = self.get()
        return list(zip(names, values))
When instantiating the metric, set the ``num`` parameter (e.g. ``Multi_Accuracy(num=3)``) to specify how many per-task accuracies to compute.
# Usage example: register the multi-task accuracy with a composite metric.
from my_metric import *
eval_metric = mx.metric.CompositeEvalMetric()
# num=2 -> report one accuracy per task for a two-task network.
eval_metric.add(Multi_Accuracy(num=2))
The following cross-entropy/recall and precision metric functions can be specified for single task or multi-task by modifying num and name.
-
cross-entropy
class CrossEntropy(mx.metric.EvalMetric):
    """Cross-entropy metric supporting single- or multi-task outputs.

    Parameters
    ----------
    eps : float
        Small constant added to probabilities before the log to avoid
        log(0).
    name : str
        Reported metric name.
    num : int or None
        Number of tasks; None means single-task (scalar counters).

    Notes
    -----
    Task index 1 receives special treatment: samples labeled -1 there
    are excluded from the loss (see the HACK comment in ``update``).
    """

    def __init__(self, eps=1e-12, name='cross-entropy',
                 output_names=None, label_names=None, num=None):
        super(CrossEntropy, self).__init__(
            name, eps=eps,
            output_names=output_names, label_names=label_names)
        self.eps = eps
        self.num = num
        self.name = name
        self.reset()

    def reset(self):
        # reset() is first called from the base __init__, before self.num
        # is assigned, hence the getattr guard.
        if getattr(self, 'num', None) is None:
            self.num_inst = 0
            self.sum_metric = 0.0
        else:
            self.num_inst = [0] * self.num
            self.sum_metric = [0.0] * self.num

    def update(self, labels, preds):
        """Accumulate -log(prob of true class) per task for one batch."""
        mx.metric.check_label_shapes(labels, preds)
        for i, (label, pred) in enumerate(zip(labels, preds)):
            label = label.asnumpy().ravel()
            pred = pred.asnumpy()
            assert label.shape[0] == pred.shape[0]
            ignored = 0
            if i == 1:
                # HACK: task 1 marks "sexy" images with label -1; they must
                # not contribute to the loss. Force their predicted
                # probability to 1 so -log(prob) is 0, and exclude them
                # from the instance count.
                sexy_index = np.where(np.int64(label) == -1)
                # BUGFIX: len(sexy_index) is the length of the tuple
                # returned by np.where (always 1), not the number of
                # masked samples; len(sexy_index[0]) is the true count.
                ignored = len(sexy_index[0])
                label[sexy_index] = 0.0  # any valid class index works
                pred[sexy_index] = np.ones((ignored, 2))
            prob = pred[np.arange(label.shape[0]), np.int64(label)]
            loss = (-np.log(prob + self.eps)).sum()
            count = label.shape[0] - ignored  # ignored == 0 for i != 1
            if self.num is None:
                self.sum_metric += loss
                self.num_inst += count
            else:
                self.sum_metric[i] += loss
                self.num_inst[i] += count

    def get(self):
        """Return (name, value) — a scalar or a per-task list of losses."""
        if self.num is None:
            if self.num_inst == 0:
                return (self.name, float('nan'))
            return (self.name, self.sum_metric / self.num_inst)
        # avoid shadowing the builtin `sum`
        result = [s / n if n != 0 else float('nan')
                  for s, n in zip(self.sum_metric, self.num_inst)]
        return (self.name, result)
-
recall
class Recall(mx.metric.EvalMetric):
    """Recall metric where class 0 is treated as the positive class.

    Samples whose label is -1 are skipped entirely. Pass ``num`` to keep
    one (true-positive, ground-truth-positive) counter pair per task.
    """

    def __init__(self, name, num=None):
        super(Recall, self).__init__('Recall')
        self.num = num
        self.name = name
        self.reset()

    def reset(self):
        # Guarded with getattr: the base __init__ calls reset() before
        # self.num has been assigned.
        if getattr(self, 'num', None) is None:
            self.num_inst = 0
            self.sum_metric = 0.0
        else:
            self.num_inst = [0] * self.num
            self.sum_metric = [0.0] * self.num

    def update(self, labels, preds):
        """Accumulate true positives and ground-truth positives per task."""
        mx.metric.check_label_shapes(labels, preds)
        for task, (output, truth) in enumerate(zip(preds, labels)):
            predicted = mx.nd.argmax_channel(output).asnumpy().astype('int32')
            truth = truth.asnumpy().astype('int32')
            true_positives = 0
            actual_positives = 0
            for p, t in zip(predicted.flat, truth.flat):
                if t == -1:
                    continue  # ignored sample
                if t == 0:
                    actual_positives += 1
                    if p == 0:
                        true_positives += 1
            if self.num is None:
                self.sum_metric += true_positives
                self.num_inst += actual_positives
            else:
                self.sum_metric[task] += true_positives
                self.num_inst[task] += actual_positives

    def get(self):
        """Return (name, value) — recall as TP / ground-truth positives."""
        if self.num is None:
            if self.num_inst == 0:
                return (self.name, float('nan'))
            return (self.name, self.sum_metric / self.num_inst)
        result = [tp / total if total != 0 else float('nan')
                  for tp, total in zip(self.sum_metric, self.num_inst)]
        return (self.name, result)
-
precision
class Precision(mx.metric.EvalMetric):
    """Precision metric where class 0 is treated as the positive class.

    Samples whose label is -1 are skipped entirely. Pass ``num`` to keep
    one (true-positive, predicted-positive) counter pair per task.
    """

    def __init__(self, name, num=None):
        super(Precision, self).__init__('Precision')
        self.num = num
        self.name = name
        self.reset()

    def reset(self):
        # Guarded with getattr: the base __init__ calls reset() before
        # self.num has been assigned.
        if getattr(self, 'num', None) is None:
            self.num_inst = 0
            self.sum_metric = 0.0
        else:
            self.num_inst = [0] * self.num
            self.sum_metric = [0.0] * self.num

    def update(self, labels, preds):
        """Accumulate true positives and predicted positives per task."""
        mx.metric.check_label_shapes(labels, preds)
        for task, (output, truth) in enumerate(zip(preds, labels)):
            predicted = mx.nd.argmax_channel(output).asnumpy().astype('int32')
            truth = truth.asnumpy().astype('int32')
            true_positives = 0
            predicted_positives = 0
            for p, t in zip(predicted.flat, truth.flat):
                if t == -1:
                    continue  # ignored sample
                if p == 0:
                    predicted_positives += 1
                    if t == 0:
                        true_positives += 1
            if self.num is None:
                self.sum_metric += true_positives
                self.num_inst += predicted_positives
            else:
                self.sum_metric[task] += true_positives
                self.num_inst[task] += predicted_positives

    def get(self):
        """Return (name, value) — precision as TP / predicted positives."""
        if self.num is None:
            if self.num_inst == 0:
                return (self.name, float('nan'))
            return (self.name, self.sum_metric / self.num_inst)
        result = [tp / total if total != 0 else float('nan')
                  for tp, total in zip(self.sum_metric, self.num_inst)]
        return (self.name, result)