Computing deep learning validation metrics with the TorchMetrics package

TorchMetrics offers a simple, clean, and efficient way to handle validation metrics. It ships many ready-made metric implementations, such as Accuracy, Dice, F1 Score, Recall, and MAE, so almost every common metric can be found there. At the time of writing, TorchMetrics packages 80+ evaluation metrics for a range of tasks.
Official Reference Documentation

Installing TorchMetrics is simple; just install the latest version from PyPI:

pip install torchmetrics

Basic workflow

Training and validation are normally done in mini-batches, and TorchMetrics works the same way. After the forward pass on a batch, the target values Y and the predicted values Y_PRED are passed to a TorchMetrics metric object. The metric object computes the statistics it needs for that batch and stores them internally (this stored data is called the metric's state).

Once all batches have been processed (for example, at the end of an epoch), the metric object can return the final result computed over all batches. Every metric object inherits from the Metric base class, which provides four key methods:

  • metric.forward(preds, target) – updates the metric state and returns the metric computed on the current batch only. Calling metric(preds, target) is equivalent;
  • metric.update(preds, target) – updates the state like forward, but returns nothing. If you do not need the per-batch value, prefer this method: it skips computing a result for the batch, so it is faster;
  • metric.compute() – returns the final result computed over all batches seen since the last reset. Roughly speaking, forward is update plus a compute restricted to the current batch;
  • metric.reset() – clears the state so the metric is ready for the next validation phase.

In other words: inside the loop, after obtaining the model's output for a batch, call forward or update (update is recommended unless you need the per-batch value). After all batches are done, call compute to get the final result. Finally, call reset to clear the state before the next validation epoch or the next training epoch.
For example:

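import torch
import torchmetrics
from torch.utils.data import DataLoader, TensorDataset

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Stand-in model and data (10 classes); replace with your own model and validation DataLoader
model = torch.nn.Linear(20, 10).to(device)
val_dataloader = DataLoader(
    TensorDataset(torch.randn(64, 20), torch.randint(0, 10, (64,))), batch_size=16)
# Recent TorchMetrics versions require the task type (and num_classes for multiclass)
metric = torchmetrics.Accuracy(task='multiclass', num_classes=10).to(device)

model.eval()
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(val_dataloader):
        data, target = data.to(device), target.to(device)
        output = model(data)  # logits of shape (batch, num_classes)
        # metric on current batch: forward updates the state AND returns the batch value
        # (use metric.update(output, target) instead if you don't need the batch value)
        batch_acc = metric(output, target)
        print(f"Accuracy on batch {batch_idx}: {batch_acc}")

# metric on all batches using custom accumulation
val_acc = metric.compute()
print(f"Accuracy on all data: {val_acc}")

# Resetting internal state such that metric is ready for new data
metric.reset()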

MetricCollection

The example above computes a single metric, but in practice you usually track several. TorchMetrics provides MetricCollection, which wraps multiple metrics into a single callable object with the same interface as a single metric, so we don't need to handle each metric individually.
The code is as follows:

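import torch
from torch.utils.data import DataLoader, TensorDataset
from torchmetrics import MetricCollection, Accuracy, Precision, Recall

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Stand-in model and data (10 classes); replace with your own model and validation DataLoader
model = torch.nn.Linear(20, 10).to(device)
val_dataloader = DataLoader(
    TensorDataset(torch.randn(64, 20), torch.randint(0, 10, (64,))), batch_size=16)
# collection of all validation metrics
metric_collection = MetricCollection({
    'acc': Accuracy(task='multiclass', num_classes=10),
    'prec': Precision(task='multiclass', num_classes=10, average='macro'),
    'rec': Recall(task='multiclass', num_classes=10, average='macro'),
}).to(device)

model.eval()
with torch.no_grad():
    for batch_idx, (data, target) in enumerate(val_dataloader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        # forward on the collection returns a dict of per-batch values
        batch_metrics = metric_collection(output, target)
        print(f"Metrics on batch {batch_idx}: {batch_metrics}")

val_metrics = metric_collection.compute()
print(f"Metrics on all data: {val_metrics}")
# reset every metric in the collection
metric_collection.reset()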

You can also pass a list of metrics instead of a dictionary, in which case the result keys are taken from the metric class names. A dictionary is cleaner because you choose the names yourself.

Custom metrics

Although TorchMetrics includes many common metrics, sometimes we need a specialized metric it does not provide. To create one, subclass Metric and implement the update and compute methods. In addition, register each state variable in __init__ with self.add_state(name, default, dist_reduce_fx), where dist_reduce_fx says how the state is combined across processes in distributed training (for example 'sum').
The code is as follows:

import torch
from torchmetrics import Metric

class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        # to count the correct predictions (summed across processes in distributed runs)
        self.add_state('correct', default=torch.tensor(0), dist_reduce_fx='sum')
        # to count the total predictions
        self.add_state('total', default=torch.tensor(0), dist_reduce_fx='sum')

    def update(self, preds, target):
        # preds are assumed to be predicted class labels with the same shape as target
        # update correct predictions count
        self.correct += torch.sum(preds == target)
        # update total count, numel() returns the total number of elements
        self.total += target.numel()

    def compute(self):
        # final computation over all accumulated batches
        return self.correct.float() / self.total


Source: blog.csdn.net/Joker00007/article/details/127128885