SSD源码阅读一

 个人博客:http://www.chenjianqu.com/

原文链接:http://www.chenjianqu.com/show-91.html

上一篇博客读了SSD的论文<SSD论文笔记>,原作者是在Caffe上实现,但是我对这个框架不太熟悉,因此找大佬们在Pytorch上的实现:https://github.com/amdegroot/ssd.pytorch ,github里面给出了安装和运行的步骤。这篇博客主要是通过阅读该项目的源码,加深对SSD和Pytorch的理解。

    在下载了数据集和vgg权重文件后,可以通过train.py训练SSD模型,因此我从train.py文件开始阅读。开始:::

    train.py里面首先通过参数解析器读取命令行传入的参数,根据是否启用GPU确定默认的张量类型和创建对应的文件夹:

 train.py

def str2bool(v):
    return v.lower() in ("yes", "true", "t", "1")
#初始化参数解析器
parser = argparse.ArgumentParser(description='Single Shot MultiBox Detector Training With Pytorch')
#创建一个互斥组,组内参数不可同时出现
train_set = parser.add_mutually_exclusive_group()
#数据集选择
parser.add_argument('--dataset', default='VOC', choices=['VOC', 'COCO'],
                    type=str, help='VOC or COCO')
#数据集路径,VOC_ROOT是VOC数据集的目录,定义在data/voc0712.py里
parser.add_argument('--dataset_root', default=VOC_ROOT,
                    help='Dataset root directory path')
#backbone网络
parser.add_argument('--basenet', default='vgg16_reducedfc.pth',
                    help='Pretrained base model')
#batch_size
parser.add_argument('--batch_size', default=32, type=int,
                    help='Batch size for training')
#恢复训练的权重目录
parser.add_argument('--resume', default=None, type=str,
                    help='Checkpoint state_dict file to resume training from')
#开始迭代的次数
parser.add_argument('--start_iter', default=0, type=int,
                    help='Resume training at this iter')
#数据处理使用的线程数
parser.add_argument('--num_workers', default=4, type=int,
                    help='Number of workers used in dataloading')
#启用GPU训练
parser.add_argument('--cuda', default=True, type=str2bool,
                    help='Use CUDA to train model')
#学习率
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float,
                    help='initial learning rate')
#动量系数
parser.add_argument('--momentum', default=0.9, type=float,
                    help='Momentum value for optim')
#权重衰减
parser.add_argument('--weight_decay', default=5e-4, type=float,
                    help='Weight decay for SGD')
#学习率更新系数
parser.add_argument('--gamma', default=0.1, type=float,
                    help='Gamma update for SGD')
#是否用到visdom
parser.add_argument('--visdom', default=False, type=str2bool,
                    help='Use visdom for loss visualization')
#权重保存路径
parser.add_argument('--save_folder', default='weights/',
                    help='Directory for saving checkpoint models')
#获得参数空间的对象
args = parser.parse_args()
#根据是否启用GPU确定默认的张量类型
if torch.cuda.is_available():
    if args.cuda:
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    if not args.cuda:
        print("WARNING: It looks like you have a CUDA device, but aren't " +
              "using CUDA.\nRun with --cuda for optimal training speed.")
        torch.set_default_tensor_type('torch.FloatTensor')
else:
    torch.set_default_tensor_type('torch.FloatTensor')
#权重文件夹
if not os.path.exists(args.save_folder):
    os.mkdir(args.save_folder)

    接着会执行train()函数。在该函数里面,通过传入的dataset和dataset_root参数判断读取初始化相应的数据集:

 train.py  train()

if args.dataset == 'COCO':
    if args.dataset_root == VOC_ROOT:
        if not os.path.exists(COCO_ROOT):
            parser.error('Must specify dataset_root if specifying dataset')
        print("WARNING: Using default COCO dataset_root because " +
          "--dataset_root was not specified.")
        args.dataset_root = COCO_ROOT
    #读取该数据集的配置,coco定义在data/config里面
    cfg = coco 
    #数据集对象
    dataset = COCODetection(root=args.dataset_root,
                            transform=SSDAugmentation(cfg['min_dim'],
                            MEANS))
elif args.dataset == 'VOC':
    if args.dataset_root == COCO_ROOT:
        parser.error('Must specify dataset if specifying dataset_root')
    cfg = voc
    dataset = VOCDetection(root=args.dataset_root,
                           transform=SSDAugmentation(cfg['min_dim'],
                         MEANS))

    上面代码里面的cocovoc是定义在data/config.py里面的字典,是对应数据集网络配置的参数。

data/config.py

# SSD300 CONFIGS
#VOC数据的网络配置
voc = {
    'num_classes': 21,#类别数
    'lr_steps': (80000, 100000, 120000),#学习率下降的步数
    'max_iter': 120000,#最大迭代次数
    'feature_maps': [38, 19, 10, 5, 3, 1],#输出特征图的尺寸
    'min_dim': 300,#输入图片的短边长
    'steps': [8, 16, 32, 64, 100, 300],#输出特征图每个像素对应到原图的大小,也就是原图到特征图的下采样系数
    'min_sizes': [30, 60, 111, 162, 213, 264],#计算先验框用到的Smin和Smax
    'max_sizes': [60, 111, 162, 213, 264, 315],
    'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],#各输出特征图的长宽比
    'variance': [0.1, 0.2],
    'clip': True,#是否截断先验框在图像边界内,在prior_box里会用到
    'name': 'VOC',#名称
}
coco = {
    'num_classes': 201,
    'lr_steps': (280000, 360000, 400000),
    'max_iter': 400000,
    'feature_maps': [38, 19, 10, 5, 3, 1],
    'min_dim': 300,
    'steps': [8, 16, 32, 64, 100, 300],
    'min_sizes': [21, 45, 99, 153, 207, 261],
    'max_sizes': [45, 99, 153, 207, 261, 315],
    'aspect_ratios': [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
    'variance': [0.1, 0.2],
    'clip': True,
    'name': 'COCO',
}

     在处理任何机器学习 问题之前都需要数据读取, 并进行预处理。 PyTorch 提供了很多工具使得数据的读取和预处理变得很容易 。  torch.utils.data.Dataset 是代表这一数据的抽象类,可以自己定义数据,只需继承和重写这个抽象类,需要定义__len__和__getitem__这两个函数。回到train.py train()里的代码,这里自定义的数据集对象是COCODetection和VOCDetection,分别定义在data/coco.py和voc0712.py里面,后者为例,数据集对象调用如下:

 train.py  train()

dataset = VOCDetection(root=args.dataset_root,
                       transform=SSDAugmentation(cfg['min_dim'],MEANS)
  )

    该类构造函数的第一个参数是VOC数据集的目录,第二个参数是数据增强对象,我们来看一下:SSDAugmentation类定义在augmentations.py里面。

augmentations.py

class SSDAugmentation(object):
#参数:输入分辨率,数据集RGB均值
    def __init__(self, size=300, mean=(104, 117, 123)):
        self.mean = mean
        self.size = size
        #将多个数据变形压缩到一个
        self.augment = Compose([
            ConvertFromInts(),#将整形图像数据转换为float32数据
            ToAbsoluteCoords(),
            PhotometricDistort(),#光度增强
            Expand(self.mean),#将图像随机扩展,拓展的像素值为self.mean
            RandomSampleCrop(),#随机采样裁切
            RandomMirror(),#将图像和先验框随机水平翻转
            ToPercentCoords(),#将先验框的中心和长宽除以图像的宽和高
            Resize(self.size),#将图像数据缩放到self.size*self.size
            SubtractMeans(self.mean) #将图像数据减去均值
        ])
    def __call__(self, img, boxes, labels):
        return self.augment(img, boxes, labels)

    这里面调用的数据增强类都定义在augmentations.py里面,可以看一下Compose这个类,将多个数据增强压缩到一个:

augmentations.py

class Compose(object):
    def __init__(self, transforms):
        self.transforms = transforms
    #该函数将类实例变成可调用对象
    def __call__(self, img, boxes=None, labels=None):
        #依次执行数据变形
        for t in self.transforms:
            img, boxes, labels = t(img, boxes, labels)
        return img, boxes, labels

    回到train.py train()函数里面,VOCDetection定义在voc0712.py里面。如下:

data/voc0712.py

class VOCDetection(data.Dataset):
    def __init__(self, 
                 root,#数据集根目录
                 image_sets=[('2007', 'trainval'), ('2012', 'trainval')],#数据集
                 transform=None, #数据增强
                 target_transform=VOCAnnotationTransform(),#标签数据增强
                 dataset_name='VOC0712'#数据集名称
        ):
        self.root = root
        self.image_set = image_sets
        self.transform = transform
        self.target_transform = target_transform
        self.name = dataset_name
        
        self._annopath = osp.join('%s', 'Annotations', '%s.xml')
        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
        
        self.ids = list() #ids存放所有的图片路径
        for (year, name) in image_sets:
            rootpath = osp.join(self.root, 'VOC' + year)
            for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
                self.ids.append((rootpath, line.strip()))

                
    #需要覆盖的函数,获取索引数据,返回数据集中的第i个样本
    def __getitem__(self, index):
        im, gt, h, w = self.pull_item(index)
        return im, gt
    
    #需要覆盖的函数,返回数据集的大小
    def __len__(self):
        return len(self.ids)
    
    #获取第i个图片
    def pull_item(self, index):
        img_id = self.ids[index]
        #读取xml注释文件
        target = ET.parse(self._annopath % img_id).getroot()
        #读取图像
        img = cv2.imread(self._imgpath % img_id)
        height, width, channels = img.shape
        #xml解析
        if self.target_transform is not None:
            target = self.target_transform(target, width, height)
        #图像数据增强
        if self.transform is not None:
            target = np.array(target)
            img, boxes, labels = self.transform(img, target[:, :4], target[:, 4])
            img = img[:, :, (2, 1, 0)]#将BGR图像转换为RBG图像
            #将gt box和标签融合为一个张量
            target = np.hstack((boxes, np.expand_dims(labels, axis=1)))
        return torch.from_numpy(img).permute(2, 0, 1), target, height, width

#其中xml解析器如下
class VOCAnnotationTransform(object):
    def __init__(self, class_to_ind=None, keep_difficult=False):
        self.class_to_ind = class_to_ind or dict(
            zip(VOC_CLASSES, range(len(VOC_CLASSES))))
        self.keep_difficult = keep_difficult

#参数:xml的内容,图片的宽,图片的高
    def __call__(self, target, width, height):
        """
        Arguments:
            target (annotation) : the target annotation to be made usable
                will be an ET.Element
        Returns:
            a list containing lists of bounding boxes  [bbox coords, class name]
        """
        res = []
	#对于每个<object>
        for obj in target.iter('object'):
	#判断difficult的目标
            difficult = int(obj.find('difficult').text) == 1
            if not self.keep_difficult and difficult:
                continue
	#获取目标<name>
            name = obj.find('name').text.lower().strip()
            bbox = obj.find('bndbox')
			
			bndbox = []
	#获取目标的左上角和右下角的坐标
            pts = ['xmin', 'ymin', 'xmax', 'ymax']
            for i, pt in enumerate(pts):
                cur_pt = int(bbox.find(pt).text) - 1
                #转换为相对值
                cur_pt = cur_pt / width if i % 2 == 0 else cur_pt / height
                bndbox.append(cur_pt)
			
	#获取目标的标签
            label_idx = self.class_to_ind[name]
            bndbox.append(label_idx)
			
            res += [bndbox]  # [xmin, ymin, xmax, ymax, label_ind]

        return res  # [[xmin, ymin, xmax, ymax, label_ind], ... ]

    再回到train.py train()函数,定义数据集之后,再调用build_ssd()这个函数构建SSD网络结构,该函数定义在ssd.py里面,如下:

ssd.py  build_ssd()

def build_ssd(phase, size=300, num_classes=21):
    #测试模式或训练模式
    if phase != "test" and phase != "train":
        print("ERROR: Phase: " + phase + " not recognized")
        return
        
    #输入分辨率只能是300
    if size != 300:
        print("ERROR: You specified size " + repr(size) + ". However, " +
              "currently only SSD300 (size=300) is supported!")
        return
        
    #获取vgg网络,参数:网络配置、输入通道数
    vggnet=vgg(base[str(size)], 3)
    
    #构造额外的层,参数:网络配置、输入通道数
    extras_layers=add_extras(extras[str(size)], 1024)
    
    #将vgg网络和额外构造的网络连接起来,参数:vgg,额外层、网络配置、类别数
    base_, extras_, head_ = multibox(vggnet,extras_layers,mbox[str(size)], num_classes)
    
    #构造SSD网络
    return SSD(phase, size, base_, extras_, head_, num_classes)

    SSD使用VGG作为backbone,详情:<SSD论文笔记>,<VGG16>。VGG16的网络结构如下图的D网络:

1.jpg

2.jpg

    这里使用pytorch预训练的vgg权重,需要看一下vgg的pytorch实现:https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py 。这里的实现如下:   

ssd.py  vgg()

def vgg(cfg, i, batch_norm=False):#cfg是vgg的配置,i是输入通道数
    layers = []
    in_channels = i
#cfg=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'C', 512, 512, 512, 'M',512, 512, 512]
#其中M表示最大池化层,C表示池化的天花板模式,即当池化的size=stride但是height/stride不是整数时,在旁边补上-NAN的值
#数字则表示输出通道数
    for v in cfg:
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        elif v == 'C':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2, ceil_mode=True)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, nn.ReLU(inplace=True)]
            in_channels = v
    pool5 = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
    
    conv6 = nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6)#dilation表示使用空洞卷积
    conv7 = nn.Conv2d(1024, 1024, kernel_size=1)
    layers += [pool5, conv6,
               nn.ReLU(inplace=True), conv7, nn.ReLU(inplace=True)]
    return layers

这段代码得到的vgg如下:
0 Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1 ReLU(inplace)
2 Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
3 ReLU(inplace)
4 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
5 Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
6 ReLU(inplace)
7 Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
8 ReLU(inplace)
9 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
10 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
11 ReLU(inplace)
12 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
13 ReLU(inplace)
14 Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
15 ReLU(inplace)
16 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
17 Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
18 ReLU(inplace)
19 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
20 ReLU(inplace)
21 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
22 ReLU(inplace)
23 MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
24 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
25 ReLU(inplace)
26 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
27 ReLU(inplace)
28 Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
29 ReLU(inplace)
30 MaxPool2d(kernel_size=3, stride=1, padding=1, dilation=1, ceil_mode=False)
31 Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(6, 6), dilation=(6, 6))
32 ReLU(inplace)
33 Conv2d(1024, 1024, kernel_size=(1, 1), stride=(1, 1))
34 ReLU(inplace)

   SSD在vgg网络的末端还另外增加几层卷积层,如下图:

3.jpg

    这里实现如下:

ssd.py  add_extras()

#构建额外的层,参数:网络配置、输入通道数
#cfg=[256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256],i=1024
def add_extras(cfg, i, batch_norm=False):
    layers = []
    in_channels = i
    flag = False
    for k, v in enumerate(cfg):
        if in_channels != 'S':
            if v == 'S':
                layers += [nn.Conv2d(in_channels, cfg[k + 1],kernel_size=(1, 3)[flag], stride=2, padding=1)]
            else:
                layers += [nn.Conv2d(in_channels, v, kernel_size=(1, 3)[flag])]
            flag = not flag
        in_channels = v
    return layers

得到的layers如下:
0 Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1))
1 Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
2 Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1))
3 Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
4 Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
5 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
6 Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1))
7 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))

    然后SSD再从多个特征图上预测先验框偏移和置信度,代码如下:

ssd.py  multibox()

#cfg=[4, 6, 6, 6, 4, 4]
def multibox(vgg, extra_layers, cfg, num_classes):
    loc_layers = []
    conf_layers = []
    #从上面vgg()的结果可知,21就是指conv4_3卷积,-2指vgg的conv7(fc7)层。
    vgg_source = [21, -2]
    #在vgg的21、-2层增加定位网络(预测先验框偏移)和置信度网络(预测置信度)
    for k, v in enumerate(vgg_source):
        loc_layers += [nn.Conv2d(vgg[v].out_channels,cfg[k] * 4, kernel_size=3, padding=1)] #定位网络
        conf_layers +=[nn.Conv2d(vgg[v].out_channels,cfg[k] * num_classes, kernel_size=3, padding=1)]#置信度网络
    #在额外层的某些层增加定位网络和置信度网络
    #A[1::2]表示A[1,3,5,7,...];enumerate(A,2)表示迭代的i从2开始
    #因此下面的语句表示取额外层的:{
1 Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
3 Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
5 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
7 Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
}   
    #后面会将这些额外层连接到定位网络和置信度网络
    for k, v in enumerate(extra_layers[1::2], 2):
        loc_layers += [nn.Conv2d(v.out_channels, cfg[k]* 4, kernel_size=3, padding=1)]
        conf_layers += [nn.Conv2d(v.out_channels, cfg[k]* num_classes, kernel_size=3, padding=1)]
    
    return vgg, extra_layers, (loc_layers, conf_layers)
    
    
base_, extras_, head_ = multibox(vggnet,extras_layers,mbox['300'], 21)

得到的head_包含定位网络和置信度网络:
定位网络
0 Conv2d(512, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1 Conv2d(1024, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
2 Conv2d(512, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
3 Conv2d(256, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
4 Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
5 Conv2d(256, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
置信度网络
0 Conv2d(512, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
1 Conv2d(1024, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
2 Conv2d(512, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
3 Conv2d(256, 126, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
4 Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
5 Conv2d(256, 84, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))

    上面这三个函数得到SSD的三部分网络之后,在SSD这个类里面组合起来。

ssd.py

class SSD(nn.Module):
    def __init__(self, phase, size, base, extras, head, num_classes):
        super(SSD, self).__init__()
        self.phase = phase #训练还是测试
        self.num_classes = num_classes #输出类别数
        self.cfg = (coco, voc)[num_classes == 21] #若num_classes == 21,cfg=voc,否则cfg=coco
        #先验框对象,定义在layers/functions/prior_box.py
        self.priorbox = PriorBox(self.cfg) 
        #设置先验框,在计算loss时用到
        self.priors = Variable(self.priorbox.forward(), volatile=True)

        self.size = size
        #构建vgg网络
        self.vgg = nn.ModuleList(base)
        #L2归一化层,用于将vgg conv4_3进行缩放
        self.L2Norm = L2Norm(512, 20) 
        #构建额外层
        self.extras = nn.ModuleList(extras)
        #构建定位网络
        self.loc = nn.ModuleList(head[0])
        #构建置信度网络
        self.conf = nn.ModuleList(head[1])
        if phase == 'test':
            self.softmax = nn.Softmax(dim=-1)
            self.detect = Detect(num_classes, 0, 200, 0.01, 0.45)

    def forward(self, x):#x: input image or batch of images. Shape: [batch,3,300,300].
        sources = list()
        loc = list()
        conf = list()
        
        #应用vgg从第一层到conv4_3 relu层
        for k in range(23):
            x = self.vgg[k](x)
        
        s = self.L2Norm(x)#从con4_3处连接L2归一化层,这点论文中也有
        
        sources.append(s)#保存到sources,下面用到用于连接到定位网络和置信度网络
        #应用vgg从con4_3到fc7
        for k in range(23, len(self.vgg)):
            x = self.vgg[k](x)
        sources.append(x)
        
        #应用额外层
        for k, v in enumerate(self.extras):
            x = F.relu(v(x), inplace=True)
            if k % 2 == 1:
                sources.append(x)
                
        #将定位网络和置信度网络连接到输出特征图上,并进行维度变换
        for (x, l, c) in zip(sources, self.loc, self.conf):
            loc.append(l(x).permute(0, 2, 3, 1).contiguous())
            conf.append(c(x).permute(0, 2, 3, 1).contiguous())
‘’‘
使用transpose或permute进行维度变换后,调用contiguous,然后方可使用view对维度进行变形。
网上有两种说法,一种是维度变换后tensor在内存中不再是连续存储的,而view操作要求连续存储,所以需要contiguous,另一种是说维度变换后的变量是之前变量的浅复制,指向同一区域,即view操作会连带原来的变量一同变形,这是不合法的,
’‘’
        #view函数的作用为重构张量的维度,相当于numpy中resize()的功能
        #因为定位网络和置信度网络都是卷积,得到是三维的特征图,这里将各网络的预测结果展开,并拼接在一起
        #o.size(0)是batch_size
        loc = torch.cat([o.view(o.size(0), -1) for o in loc], 1)
        conf = torch.cat([o.view(o.size(0), -1) for o in conf], 1)
        
        if self.phase == "test":
            output = self.detect(
                loc.view(loc.size(0), -1, 4),                   # loc preds
                self.softmax(conf.view(conf.size(0), -1,
                             self.num_classes)),                # conf preds
                self.priors.type(type(x.data))                  # default boxes
            )
        else:
            output = (
                loc.view(loc.size(0), -1, 4),#最后将定位结果resize成(batch_size,先验框的数量,总的先验框的四个偏移)
                conf.view(conf.size(0), -1, self.num_classes),#将置信度结果resize成(batch_size,先验框的数量,类别数)
                self.priors
            )
        return output

    #权重加载
    def load_weights(self, base_file):
        other, ext = os.path.splitext(base_file)
        if ext == '.pkl' or '.pth':
            print('Loading weights into state dict...')
            self.load_state_dict(torch.load(base_file, map_location=lambda storage, loc: storage))
            print('Finished!')
        else:
            print('Sorry only .pth and .pkl files supported.')

    经过上面这些代码,就构建完成了SSD的网络结构。然后发现PriorBox和L2Norm这两个类,这是啥?篇幅所限,下篇再见。

发布了74 篇原创文章 · 获赞 33 · 访问量 1万+

猜你喜欢

转载自blog.csdn.net/qq_37394634/article/details/104429967
今日推荐