(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

我的任务是在行人头肩数据上训练并测试centernet网络,先证明一下我是真的训练了哈,这是用centernet检测的一张结果(训练了10个epochs的结果,大家放心使用,网络功能还是很强大的):

我参考的这篇博客,对我自己的实验帮助很大:https://blog.csdn.net/weixin_42634342/article/details/97756458

论文作者代码https://github.com/xingyizhou/CenterNet

这个博客是整个训练的过程,可能会有点长。

1. 准备数据集

    0.(我用的数据是VOC格式的,需要将其转化为COCO格式)

    详细过程在我的另一个博客里,本来想写在这里,发现太长了,就移到另一个里了。

    链接:https://blog.csdn.net/weixin_41765699/article/details/100124689

    1. 当我们生成三个json文件之后,来到CenterNet这个工程里,在data文件夹下新建一个文件夹,名字就是你数据集的名字,如下图:

     再在这个文件夹里面建两个文件夹(annotations里面存放的是我们之前生成的那三个json文件;images存放的是所有的图片,包括训练测试验证三个,所有的):

扫描二维码关注公众号,回复: 11092374 查看本文章

    2. 在src/lib/datasets/dataset里面新建一个“ped. py”,文件内容照着文件夹下coco.py改成自己的

       0. 将COCO类改成自己的名字

       

       1. 第14行num_classes=80改成自己的类别数
       2. 第15行default_resolution(这个参数有两种(300,300)或者(512,512),很明显512的参数计算量大,300计算量小,我用的是512,之后打算训练一个300的对比一下)
       3. 接下来的mean和std改成自己图片数据集的均值和方差,脚本链接:                            https://blog.csdn.net/weixin_41765699/article/details/100118660

       4. 修改数据和图片路径,data_dir 输入的是咱们之前建立的数据集文件夹的名字,img_dir 输入的是 images 图片文件夹

       5. 修改json文件路径如下:

       6.  类别名字和类别id改成自己

       我就改了以上六点内容。

    3. 将数据集加入src/lib/datasets/dataset_factory里面

       1. 在dataset_facto字典里加入自己的数据集名字 (格式为   '你之前创建的Python文件的名字':你自己数据集类的名字,因为要从你创建的py文件里找到你的数据类,名字必须对应上)

    4. 修改/src/lib/opts.py

         1.第一步,将自己的数据集设为默认数据集,加入到help里面


  
  
  1. self.parser.add_argument( '--dataset', default= 'ped',  
  2.                                  help= 'coco | kitti | coco_hp | pascal | ped)

         2.修改ctdet任务使用的默认数据集为新添加的数据集,如下(修改分辨率,类别数,均值,方差,数据集名字):

   

         3. 修改src/lib/utils/debugger.py文件(变成自己数据的类别和名字,前后数据集名字一定保持一致)

       

            再加上自己数据的类别,不包括背景__background__

            

      到这里,准备数据集的工作就告一段落了!

2. 搭建pytorch0.4.1+cuda90+cudnn7.6.1(版本不能改,还有就是numpy的版本必须在1.13以上,建议装最新的)

     我搭建这个环境也费老大劲了,pytorch1.2貌似直接pip安装就自动装上了cuda和cudnn,0.4.1版本的我没看见有自动安装的,所以就苦哈哈自己动手装了,关于这个,我也记录了一下,大家也可以自己上网查查别的方法

cuda和cudnn安装链接:https://blog.csdn.net/weixin_41765699/article/details/99966617

torch0.4.1安装链接:https://blog.csdn.net/weixin_41765699/article/details/99756697

3. 克隆工程并运行demo

    关于工程里面这个作者写的很详细了,我是按照一步步来的,没有出错。https://github.com/xingyizhou/CenterNet/blob/master/readme/INSTALL.md

    程序里面在运行demo.py之前,会下载一个预训练权重(比如dla34,resnet18,resnet101之类的),这个不用管,等他下载完,因为我们训练的时候也要用的。(下载的时候可能会很慢,如果是在龟速的话,将他下载的网址用QQ浏览器打开自己下载,下载完放到这个它自动创建的文件夹里就可以了,QQ浏览器下载确实比其他的稍快一些)

    

   改完之后在MODEL_ZOO.md里面下载参数,ctdet_coco_dla_2x,下载完毕后放在models文件夹里面。
到这里,环境基本搭建成功,接下来可以跑一下代码了

(模型文件下载貌似要Google drive,这是我下载的:

链接:https://pan.baidu.com/s/1QOmIwy8lXJBuLv5hH5j3ag 
提取码:vwk4 )

    运行demo.py

python demo.py ctdet --demo /home/CenterNet/images/ --load_model /home/CenterNet/models/ctdet_coco_dla_2x.pth

嘿嘿嘿

     要注意的是,当弹出第一站图片的时候,按esc除外的任意键可以继续检测下一张图,想要保存检测结果的话,只需要在src/lib/detectors/cdet.py文件中:


  
  
  1.     def show_results(self, debugger, image, results):   # demo文件會調用這個函數,本函`python main.py ctdet --exp_id coco_dla --batch_size 32 --master_batch 1 --lr 1.25e-4  --gpus 0,1`數是demo時顯示圖片並保存圖片
  2.         debugger.add_img(image, img_id= 'ctdet')
  3.         for j in range( 1, self.num_classes + 1):
  4.             for bbox in results[j]:
  5.                 if bbox[ 4] > self.opt. vis_thresh:
  6.                     debugger.add_coco_bbox(bbox[ : 4], j - 1, bbox[ 4], img_id= 'ctdet')
  7.         debugger.show_all_imgs(pause= self.pause)
  8.         debugger.save_all_imgs(path= '/home/czb/CenterNet-master/output/', genID=True)

      加上一行代码,就是最后一行debugger.save_all_imgs(path='/home/CenterNet/output/', genID=True) ,path是输出路径,需要在CenterNet文件夹下新建一个文件夹output,然后再运行一遍发现检测后的图片就会保存在这个文件夹里面了。当然,去掉倒数第二行show_all_imgs,那么运行的时候就不会弹出照片了。

4. 训练阶段

     1. 定位一下发现前面自己建立的ped.py文件(修改下面的代码):


  
  
  1.   if split == 'val':
  2.             self.annot_path = os.path.join(
  3.                 self.data_dir, 'annotations',
  4.                 'val.json').format(split) # 修改test的json文件位置
  5.         else:
  6.             if opt.task == 'exdet':
  7.                 self.annot_path = os.path.join(
  8.                     self.data_dir, 'annotations',
  9.                     'train.json').format(split)
  10.             else:
  11.                 self.annot_path = os.path.join(
  12.                     self.data_dir, 'annotations',
  13.                     'train.json').format(split) # 这才是train文件

     要把第一行if split 改为 ==‘val’,这样validate时就会调用val.json文件。把最后一行要调用的文件改为‘train.json’,这样训练的时候才会调用train.json文件。改完之后数据集导入就正常了。
     2. 运行main.py

python main.py ctdet --exp_id coco_dla --batch_size 32 --master_batch 1 --lr 1.25e-4  --gpus 0,1
(如果显示显存不够之类的那种错误,需要在opts.py文件中将--num_workers改成0,batch_size改成16或者更小

   

     这时候会下载一个预训练模型,可能会很慢,我是。。下载的,这是百度盘链接,需要的可以用:

    链接:https://pan.baidu.com/s/1I1oW_l2Xe2-LV1gIjViPTg 
    提取码:2pt0 

     下载完之后放在/root/.torch/models里面

(我的是在这个里,你也可以看看他自动下载的那个在哪个文件夹里,然后把权重放在那个文件夹下)

    没有意外的话,经过上面的步骤,就开始训练了::::

5. 测试部分

     当训练完之后(我训练了两天,泰坦X,140个epochs,有点憨批,其实最好的模型是在第55个epoch出现的),在./exp/ctdet/coco_dla/文件夹下会出现如下文件

     其中,model_last是最后一次epoch的模型;model_best是val最好的模型,我选的是model_best模型;

然后开始测试。。。。。。

   1. 在我们之前建立的ped.py中修改如下部分,加入if split == ‘test’:…,作用是当test时指定标签文件为之前建立的测试文件       test.json

   2. 运行test.py

       python test.py --exp_id coco_dla --not_prefetch_test ctdet --load_model /root/CenterNet/exp/ctdet/coco_dla/model_best.pth

   不出意外的话会出现下面的画面(出现一系列AP值),其中,一般使用的是第二行,也就是IOU=0.5,全区域的AP值,其他的分别是不同IOU以及不同目标尺寸区域的结果。

  完事了。。。

2019.9.6

附加1:

我想换个骨干网络试试,作者的源代码支持resnet和hourglass,我尝试替换成resnet18,记录一下替换方法:

在原来的训练命令行命令里添加上两个参数:(顺便把exp_id 改一下,保证每个模型不乱)

python main.py ctdet --exp_id coco_res_18 --batch_size 32 --master_batch 1 --lr 1.25e-4  --gpus 0,1 --arch res_18 --head_conv 64
  
  

开始训练时也会下载相应的预训练模型,如果下载速度慢,也参照上面说的方法下载。

训练之后,在测试和运行demo的命令行代码里也要加上两个参数:--arch res_18 --head_conv 64

附加2:

训练完成的时候,我们需要绘制出loss值得曲线,以下代码可以实现该功能:

训练生成的日志文件一般在exp/ctdet/../../logs.txt,找到这个文件,打开之后会出现如下:

我们需要读取这些loss值并可视化(一般情况下,该代码只需要改变日志文件的路径即可):


  
  
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. def plot_loss_curve(log_file):
  4. loss_data = open(log_file)
  5. all_lines = loss_data.readlines()
  6. print(all_lines[ 4].split( ' '))
  7. # losses
  8. total_loss = [] # 4
  9. hm_loss = [] # 7
  10. wh_loss = [] # 10
  11. off_loss = [] # 13
  12. val_loss = [] # 19
  13. spend_time = [] # 16
  14. num_lines = len(all_lines)
  15. for line in range(num_lines):
  16. total_loss1 = all_lines[line].split( ' ')[ 4]
  17. hm_loss1 = all_lines[line].split( ' ')[ 7]
  18. wh_loss1 = all_lines[line].split( ' ')[ 10]
  19. off_loss1 = all_lines[line].split( ' ')[ 13]
  20. spend_time1 = all_lines[line].split( ' ')[ 16]
  21. total_loss. append(float(total_loss1))
  22. hm_loss. append(float(hm_loss1))
  23. wh_loss. append(float(wh_loss1))
  24. off_loss. append(float(off_loss1))
  25. spend_time. append(float(spend_time1))
  26. index_val = np.linspace( 0, 140, 29) - 1
  27. index_val = np. delete(index_val, 0, 0)
  28. for id in index_val:
  29. val_loss1 = all_lines[ int(id)].split( ' ')[ 19]
  30. val_loss. append(float(val_loss1))
  31. return val_loss, total_loss
  32. if __name__ == '__main__':
  33. # 标准图形绘制
  34. # sns.set()
  35. vloss_res18, loss_res18 = plot_loss_curve( 'logres18.txt') # 读取训练时生成的日志文件
  36. # vloss_resdcn18, loss_resdcn18 = plot_loss_curve( 'logresdcn18.txt')
  37. # vloss_dla, loss_dla = plot_loss_curve( 'logdla34.txt')
  38. # vloss_res101, loss_res101 = plot_loss_curve( 'logres101.txt')
  39. # vloss_dla34p, loss_dla34p = plot_loss_curve( 'logdla34p.txt')
  40. # vloss_hg, loss_hg = plot_loss_curve( 'loghourglass.txt')
  41. fig = plt.figure(figsize=( 10, 4))
  42. ax = fig.add_subplot( 111)
  43. ax.plot( range( len(loss_res18)), loss_res18, 'c', label= 'res_18_train_loss', linewidth= 1) # 这个label是图线自己的标签;
  44. # ax.plot( range( len(loss_resdcn18)), loss_resdcn18, 'y', label= 'resdcn_18_train_loss', linewidth= 1)
  45. # ax.plot( range( len(loss_dla)), loss_dla, 'b', label= 'dla_34_train_loss', linewidth= 1)
  46. # ax.plot( range( len(loss_res101)), loss_res101, 'g', label= 'res_101_train_loss', linewidth= 1)
  47. # ax.plot( range( len(loss_dla34p)), loss_dla34p, 'r', label= 'dla_34_train_loss', linewidth= 1)
  48. # ax.plot( range( len(loss_hg)), loss_hg, 'r', label= 'hourglass_train_loss', linewidth= 1)
  49. # ax.plot(index_val+ 1, val_loss, label= 'val_loss')
  50. # ax.plot(np.random.randn( 1000).cumsum(), label= 'line2')
  51. # ax.set_xlim([ 0, 800]) # 设置刻度;
  52. # ax.set_xticks( range( 0, 500, 100)) # 设置显示的刻度;
  53. # ax.set_yticklabels([ 'jan', 'feb', 'mar']) # 设置刻度标签;
  54. ax.set_xlabel( 'epochs') # 设置坐标轴标签;
  55. ax.set_ylabel( 'loss_value')
  56. # ax.text( 8750, 20, "海拔", color= 'red') # 加入文本
  57. ax.set_title( 'loss_of_CenterNet')
  58. ax.legend(loc= 'best') # 将图例摆放在不遮挡图线的位置即可
  59. ax.grid() # 添加网格
  60. plt.savefig( 'loss_of_CenterNet.png') # 保存文件到指定文件夹
  61. plt.show()
  62. fig1 = plt.figure(figsize=( 12, 8))
  63. ax1 = fig1.add_subplot( 111)
  64. ax1.plot( range( len(vloss_res18)), vloss_res18, 'c', label= 'res_18_val_loss', linewidth= 2) # 这个label是图线自己的标签;
  65. # ax1.plot( range( len(vloss_resdcn18)), vloss_resdcn18, 'y', label= 'resdcn_18_val_loss', linewidth= 2)
  66. # ax1.plot( range( len(vloss_dla)), vloss_dla, 'b', label= 'dla_34_val_loss', linewidth= 2)
  67. # ax1.plot( range( len(vloss_res101)), vloss_res101, 'g', label= 'res_101_val_loss', linewidth= 2)
  68. # ax1.plot( range( len(vloss_dla34p)), vloss_dla34p, 'r', label= 'dla_34_val_loss_p', linewidth= 2)
  69. # ax.plot(index_val+ 1, val_loss, label= 'val_loss')
  70. # ax.plot(np.random.randn( 1000).cumsum(), label= 'line2')
  71. # ax.set_xlim([ 0, 800]) # 设置刻度;
  72. # ax.set_xticks( range( 0, 500, 100)) # 设置显示的刻度;
  73. # ax.set_yticklabels([ 'jan', 'feb', 'mar']) # 设置刻度标签;
  74. ax1.set_xlabel( 'epochs') # 设置坐标轴标签;
  75. ax1.set_ylabel( 'loss_value')
  76. # ax.text( 8750, 20, "海拔", color= 'red') # 加入文本
  77. ax1.set_title( 'val_loss_of_CenterNet')
  78. ax1.legend(loc= 'best') # 将图例摆放在不遮挡图线的位置即可
  79. ax1.grid() # 添加网格
  80. plt.savefig( 'val_loss_of_CenterNet.png') # 保存文件到指定文件夹
  81. plt.show()
        <div class="person-messagebox">
            <div class="left-message"><a href="https://blog.csdn.net/weixin_41765699">
                <img src="https://profile.csdnimg.cn/1/2/9/3_weixin_41765699" class="avatar_pic" username="weixin_41765699">
            </a></div>
            <div class="middle-message">
                                    <div class="title"><span class="tit "><a href="https://blog.csdn.net/weixin_41765699" data-report-click="{&quot;mod&quot;:&quot;popu_379&quot;,&quot;ab&quot;:&quot;new&quot;}" target="_blank">linbior</a></span>
                    <!-- 等级,level -->
                                            <img class="identity-icon" src="https://csdnimg.cn/identity/blog5.png">                                            </div>
                <div class="text"><span>原创文章 46</span><span>获赞 80</span><span>访问量 14万+</span></div>
            </div>
                            <div class="right-message">
                                        <a class="btn btn-sm  bt-button personal-watch" data-report-click="{&quot;mod&quot;:&quot;popu_379&quot;,&quot;ab&quot;:&quot;new&quot;}">关注</a>
                                                            <a href="https://im.csdn.net/im/main.html?userName=weixin_41765699" target="_blank" class="btn btn-sm bt-button personal-letter">私信
                    </a>
                                </div>
                        </div>
                    
    </div>
发布了0 篇原创文章 · 获赞 2 · 访问量 1017

我的任务是在行人头肩数据上训练并测试centernet网络,先证明一下我是真的训练了哈,这是用centernet检测的一张结果(训练了10个epochs的结果,大家放心使用,网络功能还是很强大的):

猜你喜欢

转载自blog.csdn.net/yhl41001/article/details/105722995
今日推荐