PyTorch目标检测(八)

上接(七) 在PyTorch上移植 tinyssd做刚性物体的目标检测

自定义损失函数

# Loss functions
# Classification loss: mean cross-entropy.
cls_loss = nn.CrossEntropyLoss(weight=None, reduction='mean')
# Box-regression loss: mean absolute value of the corner offsets.
bbox_loss = nn.L1Loss(reduction='mean')


def calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks):
    """Return the total detection loss: classification loss + box loss.

    Args:
        cls_preds: class scores, passed unchanged to ``nn.CrossEntropyLoss``
            (so they must already be laid out as that criterion expects,
            e.g. ``(N, C)`` logits).
        cls_labels: integer class targets matching ``cls_preds``.
        bbox_preds: predicted box offsets.
        bbox_labels: target box offsets, same shape as ``bbox_preds``.
        bbox_masks: multiplicative mask applied to both predictions and
            targets before the L1 loss. Presumably 0 for background
            anchors and 1 otherwise -- confirm against the caller.

    Returns:
        Scalar tensor ``cls_loss + bbox_loss``.
    """
    loss1 = cls_loss(cls_preds, cls_labels)
    # Masked entries become 0 on both sides, so they add nothing to the
    # numerator, but reduction='mean' still counts them in the denominator.
    loss2 = bbox_loss(bbox_preds * bbox_masks, bbox_labels * bbox_masks)
    return loss1 + loss2

# 评价函数
def cls_eval(cls_preds, cls_labels):
    """Count how many class predictions are correct.

    Args:
        cls_preds: tensor of class scores; the predicted class is the index
            of the maximum along the last dimension.
        cls_labels: tensor of target class indices, comparable element-wise
            with the predicted indices.

    Returns:
        0-d NumPy array holding the number of positions where the predicted
        class equals the label.
    """
    _, indices = cls_preds.max(-1)
    # detach()/cpu() make the NumPy conversion safe for GPU tensors.
    correct = torch.sum(indices == cls_labels).detach().cpu().numpy()
    return correct

def bbox_eval(bbox_preds, bbox_labels, bbox_masks):
    """Sum the absolute corner-offset errors over the unmasked entries.

    Args:
        bbox_preds: predicted box offsets.
        bbox_labels: target box offsets, same shape as ``bbox_preds``.
        bbox_masks: multiplicative mask; entries multiplied by 0 do not
            contribute to the total.

    Returns:
        0-d NumPy array with the summed absolute error.
    """
    abs_err = ((bbox_labels - bbox_preds) * bbox_masks).abs()
    # Model outputs normally require grad (and may live on the GPU), where a
    # bare .numpy() raises; detach and move to CPU first.
    return torch.sum(abs_err).detach().cpu().numpy()

# 锚框产生
def MultiBoxPrior(feature_map, sizes=(0.75, 0.5, 0.25), ratios=(1, 2, 0.5)):
    """Generate anchor boxes for every cell of a feature map.

    The first size is paired with every ratio, and each remaining size is
    paired with the first ratio, giving ``len(sizes) + len(ratios) - 1``
    anchors per cell.

    Args:
        feature_map: torch tensor whose last two dimensions are (H, W),
            e.g. of shape [N, C, H, W]. Only the shape is read.
        sizes: Non-empty sequence of anchor sizes (0~1).
        ratios: Non-empty sequence of positive aspect ratios.

    Returns:
        float32 tensor of shape (1, num_anchors, 4) with
        ``num_anchors = H * W * (len(sizes) + len(ratios) - 1)``. Each row
        is (xmin, ymin, xmax, ymax) in coordinates relative to the feature
        map extent. The first dimension is 1 because the anchors are the
        same for every sample in a batch.

    Raises:
        ValueError: if ``sizes`` or ``ratios`` is empty.
    """
    if len(sizes) == 0 or len(ratios) == 0:
        raise ValueError("sizes and ratios must each contain at least one value")

    size_arr = np.asarray(sizes, dtype=np.float64)
    sqrt_ratio_arr = np.sqrt(np.asarray(ratios, dtype=np.float64))

    # (size, sqrt(ratio)) pairs: sizes[0] with every ratio first, then the
    # remaining sizes with ratios[0].
    pair_sizes = np.concatenate(
        [np.full(len(sqrt_ratio_arr), size_arr[0]), size_arr[1:]]
    )
    pair_sqrt_ratios = np.concatenate(
        [sqrt_ratio_arr, np.full(len(size_arr) - 1, sqrt_ratio_arr[0])]
    )

    widths = pair_sizes * pair_sqrt_ratios   # size * sqrt(ratio)
    heights = pair_sizes / pair_sqrt_ratios  # size / sqrt(ratio)

    # Anchors centred on the origin, as (xmin, ymin, xmax, ymax).
    base_anchors = np.stack([-widths, -heights, widths, heights], axis=1) / 2

    h, w = feature_map.shape[-2:]
    # Anchor centres sit at i / w and j / h (no half-cell offset is added).
    shifts_x = np.arange(0, w) / w
    shifts_y = np.arange(0, h) / h
    shift_x, shift_y = np.meshgrid(shifts_x, shifts_y)
    shift_x = shift_x.reshape(-1)
    shift_y = shift_y.reshape(-1)
    shifts = np.stack((shift_x, shift_y, shift_x, shift_y), axis=1)

    # Broadcast (H*W, 1, 4) + (1, anchors_per_cell, 4).
    anchors = shifts.reshape((-1, 1, 4)) + base_anchors.reshape((1, -1, 4))

    return torch.tensor(anchors, dtype=torch.float32).view(1, -1, 4)

数据读取

import os
import torch
from torchvision import transforms
import torch.utils.data as data
from PIL import Image
import numpy as np
import xml.etree.ElementTree as ET

classname = [
    'redbox', 'matrix', 'bluebox', 'beer',
    'redbull', 'ball', 'AD', 'milk',
]


# Read the XML annotation of the i-th sample.
def get_example(self, i):
    """Parse the annotation file of sample ``i`` into one flat array.

    Reads ``<root_dir>/Annotations/<id>.xml`` and, for every ``object``
    element, collects the box corners (ymin, xmin, ymax, xmax), each
    reduced by 1, together with the index of the object's name in
    ``classname``.

    Args:
        self: object providing ``ids`` (sequence of sample ids) and
            ``root_dir`` (dataset root directory).
        i: index into ``self.ids``.

    Returns:
        1-D NumPy array: all box coordinates first, followed by all labels.
        NOTE(review): ``np.append`` without ``axis`` flattens both inputs, so
        with several objects the result is NOT one row per box.
    """
    sample_id = self.ids[i]
    xml_path = os.path.join(self.root_dir, 'Annotations', sample_id + '.xml')
    tree = ET.parse(xml_path)

    corner_tags = ('ymin', 'xmin', 'ymax', 'xmax')
    boxes, labels = [], []
    for node in tree.findall('object'):
        corners = node.find('bndbox')
        boxes.append([int(corners.find(tag).text) - 1 for tag in corner_tags])
        labels.append(classname.index(node.find('name').text))

    return np.append(np.stack(boxes), np.stack(labels))

def getimg(self, idx):
    """Load sample ``idx`` as an image and apply ``self.transform``.

    Args:
        self: object providing ``ids``, ``root_dir`` and ``transform``.
        idx: index into ``self.ids``.

    Returns:
        Whatever ``self.transform`` returns for the RGB image read from
        ``<root_dir>/JPEGImages/<id>.jpg``.
    """
    id_ = self.ids[idx]
    img_path = os.path.join(self.root_dir, 'JPEGImages', id_ + '.jpg')
    # The context manager closes the file handle; convert('RGB') forces the
    # pixel data to load and guarantees three channels, so grayscale or CMYK
    # JPEGs do not break a 3-channel transform.
    with Image.open(img_path) as raw:
        img = raw.convert('RGB')
    return self.transform(img)


# 数据读取
class my_date(data.Dataset):
    """Dataset over a VOC-style directory tree.

    Expects ``root_dir`` to contain ``Annotations/``, ``JPEGImages/`` and
    ``ImageSets/<name>.txt``, the latter listing one sample id per line.
    """

    def __init__(self, root_dir, name):
        """Record the dataset paths and read the sample ids.

        Args:
            root_dir: dataset root directory.
            name: split name; ids are read from ``ImageSets/<name>.txt``.
        """
        self.root_dir = root_dir
        self.annopath = os.path.join(root_dir, 'Annotations')
        self.imgpath = os.path.join(root_dir, 'JPEGImages')
        self.idpath = os.path.join(root_dir, 'ImageSets', name + '.txt')
        # ToTensor then per-channel (x - 0.5) / 0.5 normalisation.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
        # Strip the trailing newline (otherwise it ends up inside every file
        # path built from the id) and skip blank lines; close the file when
        # done.
        with open(self.idpath) as id_file:
            self.ids = [line.strip() for line in id_file if line.strip()]

    def __getitem__(self, idx):
        """Return ``(image, annotation)`` for sample ``idx``."""
        return getimg(self, idx), get_example(self, idx)

    def __len__(self):
        """Return the number of sample ids."""
        return len(self.ids)
        
# Smoke test: load the 'train' split and print its first sample.
# Named `dataset` rather than `data` so the `torch.utils.data as data`
# import alias is not overwritten.
dataset = my_date('./czkdata', 'train')
print(dataset[0])

这里参考VOC2012数据集的目录结构建立自己的数据集;VOC2012数据集的解析可以参考相关博客(原文此处的链接缺失)。
这里my_date读取之后返回的格式为:归一化后的3通道tensor图像,加上一个一维数组[各方框左上角与右下角的坐标(ymin, xmin, ymax, xmax), 各方框内物体的类别序号]。

发布了25 篇原创文章 · 获赞 2 · 访问量 2099

猜你喜欢

转载自blog.csdn.net/weixin_43874764/article/details/104442099