faster rcnn修改demo.py保存网络中间结果

转自 https://blog.csdn.net/u010668907/article/details/51439503

faster rcnn用python版本https://github.com/rbgirshick/py-faster-rcnn

以demo.py中默认网络VGG16.

原本demo.py地址https://github.com/rbgirshick/py-faster-rcnn/blob/master/tools/demo.py

图有点多,贴一个图的本分结果出来:


上图是原图,下面第一张是网络中命名为“conv1_1”的结果图;第二张是命名为“rpn_cls_prob_reshape”的结果图;第三张是“rpnoutput”的结果图

看一下我修改后的代码:

[python]  view plain  copy
  1. #!/usr/bin/env python  
  2.   
  3. # --------------------------------------------------------  
  4. # Faster R-CNN  
  5. # Copyright (c) 2015 Microsoft  
  6. # Licensed under The MIT License [see LICENSE for details]  
  7. # Written by Ross Girshick  
  8. # --------------------------------------------------------  
  9.   
  10. """ 
  11. Demo script showing detections in sample images. 
  12.  
  13. See README.md for installation instructions before running. 
  14. """  
  15.   
  16. import _init_paths  
  17. from fast_rcnn.config import cfg  
  18. from fast_rcnn.test import im_detect  
  19. from fast_rcnn.nms_wrapper import nms  
  20. from utils.timer import Timer  
  21. import matplotlib.pyplot as plt  
  22. import numpy as np  
  23. import scipy.io as sio  
  24. import caffe, os, sys, cv2  
  25. import argparse  
  26. import math  
  27.   
  28. CLASSES = ('__background__',  
  29.            'aeroplane''bicycle''bird''boat',  
  30.            'bottle''bus''car''cat''chair',  
  31.            'cow''diningtable''dog''horse',  
  32.            'motorbike''person''pottedplant',  
  33.            'sheep''sofa''train''tvmonitor')  
  34.   
  35. NETS = {'vgg16': ('VGG16',  
  36.                   'VGG16_faster_rcnn_final.caffemodel'),  
  37.         'zf': ('ZF',  
  38.                   'ZF_faster_rcnn_final.caffemodel')}  
  39.   
  40.   
  41. def vis_detections(im, class_name, dets, thresh=0.5):  
  42.     """Draw detected bounding boxes."""  
  43.     inds = np.where(dets[:, -1] >= thresh)[0]  
  44.     if len(inds) == 0:  
  45.         return  
  46.   
  47.     im = im[:, :, (210)]  
  48.     fig, ax = plt.subplots(figsize=(1212))  
  49.     ax.imshow(im, aspect='equal')  
  50.     for i in inds:  
  51.         bbox = dets[i, :4]  
  52.         score = dets[i, -1]  
  53.   
  54.         ax.add_patch(  
  55.             plt.Rectangle((bbox[0], bbox[1]),  
  56.                           bbox[2] - bbox[0],  
  57.                           bbox[3] - bbox[1], fill=False,  
  58.                           edgecolor='red', linewidth=3.5)  
  59.             )  
  60.         ax.text(bbox[0], bbox[1] - 2,  
  61.                 '{:s} {:.3f}'.format(class_name, score),  
  62.                 bbox=dict(facecolor='blue', alpha=0.5),  
  63.                 fontsize=14, color='white')  
  64.   
  65.     ax.set_title(('{} detections with '  
  66.                   'p({} | box) >= {:.1f}').format(class_name, class_name,  
  67.                                                   thresh),  
  68.                   fontsize=14)  
  69.     plt.axis('off')  
  70.     plt.tight_layout()  
  71.     #plt.draw()  
  72. def save_feature_picture(data, name, image_name=None, padsize = 1, padval = 1):  
  73.     data = data[0]  
  74.     #print "data.shape1: ", data.shape  
  75.     n = int(np.ceil(np.sqrt(data.shape[0])))  
  76.     padding = ((0, n ** 2 - data.shape[0]), (00), (0, padsize)) + ((00),) * (data.ndim - 3)  
  77.     #print "padding: ", padding  
  78.     data = np.pad(data, padding, mode='constant', constant_values=(padval, padval))  
  79.     #print "data.shape2: ", data.shape  
  80.       
  81.     data = data.reshape((n, n) + data.shape[1:]).transpose((0213) + tuple(range(4, data.ndim + 1)))  
  82.     #print "data.shape3: ", data.shape, n  
  83.     data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])  
  84.     #print "data.shape4: ", data.shape  
  85.     plt.figure()  
  86.     plt.imshow(data,cmap='gray')  
  87.     plt.axis('off')  
  88.     #plt.show()  
  89.     if image_name == None:  
  90.         img_path = './data/feature_picture/'   
  91.     else:  
  92.         img_path = './data/feature_picture/' + image_name + "/"  
  93.         check_file(img_path)  
  94.     plt.savefig(img_path + name + ".jpg", dpi = 400, bbox_inches = "tight")  
  95. def check_file(path):  
  96.     if not os.path.exists(path):  
  97.         os.mkdir(path)  
  98. def demo(net, image_name):  
  99.     """Detect object classes in an image using pre-computed object proposals."""  
  100.   
  101.     # Load the demo image  
  102.     im_file = os.path.join(cfg.DATA_DIR, 'demo', image_name)  
  103.     im = cv2.imread(im_file)  
  104.   
  105.     # Detect all object classes and regress object bounds  
  106.     timer = Timer()  
  107.     timer.tic()  
  108.     scores, boxes = im_detect(net, im)  
  109.     for k, v in net.blobs.items():  
  110.         if k.find("conv")>-1 or k.find("pool")>-1 or k.find("rpn")>-1:  
  111.             save_feature_picture(v.data, k.replace("/", ""), image_name)#net.blobs["conv1_1"].data, "conv1_1")   
  112.     timer.toc()  
  113.     print ('Detection took {:.3f}s for '  
  114.            '{:d} object proposals').format(timer.total_time, boxes.shape[0])  
  115.   
  116.     # Visualize detections for each class  
  117.     CONF_THRESH = 0.8  
  118.     NMS_THRESH = 0.3  
  119.     for cls_ind, cls in enumerate(CLASSES[1:]):  
  120.         cls_ind += 1 # because we skipped background  
  121.         cls_boxes = boxes[:, 4*cls_ind:4*(cls_ind + 1)]  
  122.         cls_scores = scores[:, cls_ind]  
  123.         dets = np.hstack((cls_boxes,  
  124.                           cls_scores[:, np.newaxis])).astype(np.float32)  
  125.         keep = nms(dets, NMS_THRESH)  
  126.         dets = dets[keep, :]  
  127.         vis_detections(im, cls, dets, thresh=CONF_THRESH)  
  128.   
  129. def parse_args():  
  130.     """Parse input arguments."""  
  131.     parser = argparse.ArgumentParser(description='Faster R-CNN demo')  
  132.     parser.add_argument('--gpu', dest='gpu_id', help='GPU device id to use [0]',  
  133.                         default=0, type=int)  
  134.     parser.add_argument('--cpu', dest='cpu_mode',  
  135.                         help='Use CPU mode (overrides --gpu)',  
  136.                         action='store_true')  
  137.     parser.add_argument('--net', dest='demo_net', help='Network to use [vgg16]',  
  138.                         choices=NETS.keys(), default='vgg16')  
  139.   
  140.     args = parser.parse_args()  
  141.   
  142.     return args  
  143.   
  144. def print_param(net):  
  145.     for k, v in net.blobs.items():  
  146.     print (k, v.data.shape)  
  147.     print ""  
  148.     for k, v in net.params.items():  
  149.     print (k, v[0].data.shape)    
  150.   
  151. if __name__ == '__main__':  
  152.     cfg.TEST.HAS_RPN = True  # Use RPN for proposals  
  153.   
  154.     args = parse_args()  
  155.   
  156.     prototxt = os.path.join(cfg.MODELS_DIR, NETS[args.demo_net][0],  
  157.                             'faster_rcnn_alt_opt''faster_rcnn_test.pt')  
  158.     #print "prototxt: ", prototxt  
  159.     caffemodel = os.path.join(cfg.DATA_DIR, 'faster_rcnn_models',  
  160.                               NETS[args.demo_net][1])  
  161.   
  162.     if not os.path.isfile(caffemodel):  
  163.         raise IOError(('{:s} not found.\nDid you run ./data/script/'  
  164.                        'fetch_faster_rcnn_models.sh?').format(caffemodel))  
  165.   
  166.     if args.cpu_mode:  
  167.         caffe.set_mode_cpu()  
  168.     else:  
  169.         caffe.set_mode_gpu()  
  170.         caffe.set_device(args.gpu_id)  
  171.         cfg.GPU_ID = args.gpu_id  
  172.     net = caffe.Net(prototxt, caffemodel, caffe.TEST)  
  173.       
  174.     #print_param(net)  
  175.   
  176.     print '\n\nLoaded network {:s}'.format(caffemodel)  
  177.   
  178.     # Warmup on a dummy image  
  179.     im = 128 * np.ones((3005003), dtype=np.uint8)  
  180.     for i in xrange(2):  
  181.         _, _= im_detect(net, im)  
  182.   
  183.     im_names = ['000456.jpg''000542.jpg''001150.jpg',  
  184.                 '001763.jpg''004545.jpg']  
  185.     for im_name in im_names:  
  186.         print '~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~'  
  187.         print 'Demo for data/demo/{}'.format(im_name)  
  188.         demo(net, im_name)  
  189.   
  190.     #plt.show()  
1.在data下手动创建“feature_picture”文件夹就可以替换原来的demo使用了。

2.上面代码主要添加方法是:save_feature_picture,它会对网络测试的某些阶段的数据处理然后保存。

3.某些阶段是因为:if k.find("conv")>-1 or k.find("pool")>-1 or k.find("rpn")>-1这行代码(110行),保证网络层name有这三个词的才会被保存,因为其他层无法用图片

保存,如全连接(参数已经是二维的了)等层。

4.放开174行print_param(net)的注释,就可以看到网络参数的输出。

5.执行的最终结果 是在data/feature_picture产生以图片名字为文件夹名字的文件夹,文件夹下有以网络每层name为名字的图片。

6.另外部分网络的层name中有非法字符不能作为图片名字,我在代码的111行只是把‘字符/’剔除掉了,所以建议网络名字不要又其他字符。

图片下载和代码下载方式:

[plain]  view plain  copy
  1. git clone https://github.com/meihuakaile/faster-rcnn.git  

猜你喜欢

转载自blog.csdn.net/weixin_39970417/article/details/80744825