YOLOv7 Training and TensorRT Implementation


GitHub: https://github.com/WongKinYiu/yolov7


Train

  1. hyp.scratch.p5.yaml

    1. mosaic: probability of applying mosaic augmentation (default 1; 0 disables it). When mosaic is applied, a second random draw uses a 4-image mosaic with probability 0.8 and a 9-image mosaic otherwise (see the snippet below). Choose a value that suits your training set.

          def __getitem__(self, index):
              index = self.indices[index]  # linear, shuffled, or image_weights
      
              hyp = self.hyp
              mosaic = self.mosaic and random.random() < hyp['mosaic']
              if mosaic:
                  # Load mosaic
                  if random.random() < 0.8:
                      img, labels = load_mosaic(self, index)
                  else:
                      img, labels = load_mosaic9(self, index)
                  shapes = None
      
                  # MixUp https://arxiv.org/pdf/1710.09412.pdf
                  if random.random() < hyp['mixup']:
                      if random.random() < 0.8:
                          img2, labels2 = load_mosaic(self, random.randint(0, len(self.labels) - 1))
                      else:
                          img2, labels2 = load_mosaic9(self, random.randint(0, len(self.labels) - 1))
                      r = np.random.beta(8.0, 8.0)  # mixup ratio, alpha=beta=8.0
                      img = (img * r + img2 * (1 - r)).astype(np.uint8)
                      labels = np.concatenate((labels, labels2), 0)
      
    2. scale: random scaling; the scale factor is sampled from the range (1 - scale) to (1.1 + scale), as the snippet below shows. Choose a value that suits your training set.

          # Rotation and Scale
          R = np.eye(3)
          a = random.uniform(-degrees, degrees)
          # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
          s = random.uniform(1 - scale, 1.1 + scale)
          # s = 2 ** random.uniform(-scale, scale)
          R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
    
    3. anchor_t: the anchor-multiple threshold (maximum allowed width/height ratio between a target box and an anchor) used when selecting positive samples; see the sketch after this item.
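      For reference, build_targets() in utils/loss.py applies this threshold by comparing each target's width and height against each anchor and dropping candidates whose ratio (in either direction) exceeds anchor_t. Below is a minimal self-contained sketch of that filter; the tensor shapes and values are made up just so it runs.

          import torch

          anchor_t = 4.0                                                  # hyp['anchor_t']
          anchors = torch.tensor([[12., 16.], [19., 36.], [40., 28.]])    # 3 anchors of one detection layer (grid units)
          t = torch.rand(3, 8, 6) * 50                                    # candidate targets: (num_anchors, num_targets, 6), wh in columns 4:6

          r = t[:, :, 4:6] / anchors[:, None]                             # width/height ratio between target and anchor
          j = torch.max(r, 1. / r).max(2)[0] < anchor_t                   # keep only ratios within the threshold in both directions
          positives = t[j]                                                # targets retained as positive samples
          print(positives.shape)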
  2. data/xx.yaml: configure the training, validation, and test set paths, the number of classes (nc), and the class names

    train: ./coco/train2017.txt  # 118287 images
    val: ./coco/val2017.txt  # 5000 images
    test: ./coco/test-dev2017.txt  # 20288 of 40670 images, submit to https://competitions.codalab.org/competitions/20794
    
    # number of classes
    nc: 80
    
    # class names
    names: [ "", ""]
    
  3. cfg/training/yolov7.yaml: change nc to match your dataset

  4. train.py: set the corresponding --cfg, --data, --batch-size, and other arguments

  5. By default the trained weights are saved under runs/train (e.g. runs/train/exp/weights/best.pt)
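  A quick way to inspect what was saved (a sketch, not from the original post): run it from the yolov7 repo root so the pickled model classes can be resolved, and adjust the path if your run folder is not the default runs/train/exp.

    import torch

    # Load the training checkpoint saved by train.py (path assumes the default run folder).
    ckpt = torch.load('runs/train/exp/weights/best.pt', map_location='cpu')
    print(ckpt.keys())            # e.g. 'model', 'ema', 'optimizer', 'epoch', ...
    print(ckpt['model'].names)    # class names stored with the model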


Re-parameterization

  1. A model you trained yourself must be re-parameterized after training; otherwise the later TensorRT inference results will be wrong.

  2. The official models are already re-parameterized, so when you later run TensorRT inference you will find that the official weights give correct results while your own trained weights do not. I struggled with this for several days, and it almost made me question everything.

  3. References:

    1. https://www.iotword.com/2642.html
    2. https://github.com/WongKinYiu/yolov7/blob/main/tools/reparameterization.ipynb
    from copy import deepcopy
    import torch.utils.data
    from models.yolo import Model
    from utils.torch_utils import select_device, is_parallel
    
    nc = 80       # set this to the number of classes of YOUR trained model
    anchors = 3   # anchors per detection layer
    device = select_device('0', batch_size=1)
    ckpt = torch.load('weights/best.pt', map_location=device)  # training checkpoint to re-parameterize
    # reparameterized model config in cfg/deploy/*.yaml
    model = Model('cfg/deploy/yolov7.yaml', ch=3, nc=nc).to(device)
    # print(model)
    
    # copy intersect weights
    state_dict = ckpt['model'].float().state_dict()
    exclude = []
    intersect_state_dict = {k: v for k, v in state_dict.items()
                            if k in model.state_dict()
                            and not any(x in k for x in exclude)
                            and v.shape == model.state_dict()[k].shape}
    model.load_state_dict(intersect_state_dict, strict=False)
    model.names = ckpt['model'].names
    model.nc = ckpt['model'].nc
    
    # (debug) list the keys of the training checkpoint
    for i in state_dict:
        print(i)
    # print(intersect_state_dict)
    
    # reparametrized YOLOR head: fold the implicit knowledge (ia / im) into the conv weights and biases
    for i in range((model.nc + 5) * anchors):
        model.state_dict()['model.105.m.0.weight'].data[i, :, :, :] *= state_dict['model.105.im.0.implicit'].data[:, i, ::].squeeze()
        model.state_dict()['model.105.m.1.weight'].data[i, :, :, :] *= state_dict['model.105.im.1.implicit'].data[:, i, ::].squeeze()
        model.state_dict()['model.105.m.2.weight'].data[i, :, :, :] *= state_dict['model.105.im.2.implicit'].data[:, i, ::].squeeze()
    model.state_dict()['model.105.m.0.bias'].data += state_dict['model.105.m.0.weight'].mul(state_dict['model.105.ia.0.implicit']).sum(1).squeeze()
    model.state_dict()['model.105.m.1.bias'].data += state_dict['model.105.m.1.weight'].mul(state_dict['model.105.ia.1.implicit']).sum(1).squeeze()
    model.state_dict()['model.105.m.2.bias'].data += state_dict['model.105.m.2.weight'].mul(state_dict['model.105.ia.2.implicit']).sum(1).squeeze()
    model.state_dict()['model.105.m.0.bias'].data *= state_dict['model.105.im.0.implicit'].data.squeeze()
    model.state_dict()['model.105.m.1.bias'].data *= state_dict['model.105.im.1.implicit'].data.squeeze()
    model.state_dict()['model.105.m.2.bias'].data *= state_dict['model.105.im.2.implicit'].data.squeeze()
    
    # model to be saved
    ckpt = {'model': deepcopy(model.module if is_parallel(model) else model).half(),
            'optimizer': None,
            'training_results': None,
            'epoch': -1}
    
    # save reparameterized model
    torch.save(ckpt, 'weights/best_reparam.pt')
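    As an optional sanity check (a sketch, not part of the original script), load the re-parameterized checkpoint and run a dummy forward pass from the yolov7 repo root:

    import torch

    ckpt = torch.load('weights/best_reparam.pt', map_location='cpu')
    model = ckpt['model'].float().eval()              # the checkpoint was saved in half precision
    with torch.no_grad():
        pred = model(torch.zeros(1, 3, 640, 640))[0]  # first element holds the concatenated predictions
    print(pred.shape)                                 # (1, num_predictions, nc + 5)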
    
  4. About the REP module

    1. The REP module comes in two forms: a training form and a deploy (inference) form.
      1. The training form has three parallel branches: the top branch is a 3x3 convolution for feature extraction, the middle branch is a 1x1 convolution for feature smoothing, and the last branch is an identity that passes the input straight through with no convolution. The outputs of the three branches are added together.
      2. The deploy form contains a single 3x3 convolution with stride 1 and is obtained from the training form by re-parameterization.
        During re-parameterization, the 1x1 convolution is converted into an equivalent 3x3 kernel (zero-padded around the center) and the identity branch is likewise converted into an equivalent 3x3 kernel; the kernels are then merged by matrix addition.
        Adding the weights gives a single 3x3 convolution: the three branches collapse into one path containing only one 3x3 convolution whose weights and biases are the sums of the three branches. A runnable sketch of this fusion is given after the references below.

    2. Detailed references:

    1. https://blog.csdn.net/qq128252/article/details/126673493
    2. https://blog.csdn.net/qq_41580422/article/details/126316738
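    The fusion described above can be reproduced in a few lines. The sketch below is a simplified illustration (it omits the per-branch BatchNorm folding that the real RepConv also performs): a 3x3 branch, a 1x1 branch, and an identity branch produce the same output as a single 3x3 convolution whose kernel is the sum of the three branches.

      import torch
      import torch.nn.functional as F

      c = 4                                   # channels (the identity branch needs in_channels == out_channels)
      x = torch.randn(1, c, 8, 8)

      k3 = torch.randn(c, c, 3, 3)            # 3x3 branch
      k1 = torch.randn(c, c, 1, 1)            # 1x1 branch
      k_id = torch.zeros(c, c, 3, 3)          # identity branch expressed as a 3x3 kernel
      for i in range(c):
          k_id[i, i, 1, 1] = 1.0

      # Training-time output: three parallel branches, summed.
      y_train = F.conv2d(x, k3, padding=1) + F.conv2d(x, k1) + x

      # Deploy-time output: one fused 3x3 kernel (1x1 kernel zero-padded to 3x3).
      k_fused = k3 + F.pad(k1, [1, 1, 1, 1]) + k_id
      y_deploy = F.conv2d(x, k_fused, padding=1)

      print(torch.allclose(y_train, y_deploy, atol=1e-4))   # True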

TensorRT

  1. TensorRT implementation reference: https://github.com/QIANXUNZDL123/tensorrtx-yolov7
  2. For a model you trained yourself, generate the .wts file from the re-parameterized weights; see the format sketch after this list.
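  The linked repository ships its own script for converting a .pt checkpoint to .wts; the sketch below only illustrates the usual tensorrtx .wts layout (first line: number of tensors; then one line per tensor with its name, element count, and big-endian float32 hex values), assuming the re-parameterized checkpoint from the previous section.

    import struct
    import torch

    ckpt = torch.load('weights/best_reparam.pt', map_location='cpu')
    state = ckpt['model'].float().state_dict()

    with open('weights/best_reparam.wts', 'w') as f:
        f.write(f'{len(state)}\n')
        for name, tensor in state.items():
            values = tensor.reshape(-1).cpu().numpy()
            f.write(f'{name} {len(values)}')
            for v in values:
                f.write(' ' + struct.pack('>f', float(v)).hex())
            f.write('\n')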

END
  1. This post mainly records some of the problems I ran into during training and the TensorRT implementation.
  2. Finally, thanks to all the authors for their excellent work. Respect.


Reposted from blog.csdn.net/haiyangyunbao813/article/details/126844146