yolo测试代码梳理


[html] view plain copy
  1. <span style="font-size:18px;">darknet.c  
  2. 首先就是各种头文件  
  3. /////////////////////////////////////  
  4. #include <time.h>  
  5. #include <stdlib.h>  
  6. #include <stdio.h>  
  7.   
  8. #include "parser.h"  
  9. #include "utils.h"  
  10. #include "cuda.h"  
  11. #include "blas.h"  
  12. #include "connected_layer.h"  
  13.   
  14. #ifdef OPENCV  
  15. #include "opencv2/highgui/highgui_c.h"  
  16. #endif  
  17.   
  18. /////////////////////////////////////////////  
  19. 各种接口函数的声明,方便下面调用  
  20.   
  21. void change_rate(char *filename, float scale, float add)  
  22. void average(int argc, char *argv[])  
  23. void speed(char *cfgfile, int tics)  
  24. void operations(char *cfgfile)  
  25. 。。。。。  
  26. /////////////////////////////////////////////////  
  27. 最主要的main函数  
  28.   
  29. int main(int argc, char **argv)  
  30. {  
  31.     if(argc < 2){  
  32.         fprintf(stderr, "usage: %s <function>\n", argv[0]);  
  33.         return 0;  
  34.     }//参数小于2直接输出提示  
  35.     gpu_index = find_int_arg(argc, argv, "-i", 0);  
  36.     if(find_arg(argc, argv, "-nogpu")) {  
  37.         gpu_index = -1;  
  38.     }//设置无gpu格式  
  39.   
  40. #ifndef GPU  
  41.     gpu_index = -1;  
  42. #else  
  43.     if(gpu_index >= 0){  
  44.         cuda_set_device(gpu_index);  
  45.     }  
  46. #endif  
  47. //输入选项  
  48.     if (0 == strcmp(argv[1], "average")){  
  49.         average(argc, argv);  
  50.     } else if (0 == strcmp(argv[1], "yolo")){  
  51.         run_yolo(argc, argv);//从这里跳转出去,执行yolo----------------  
  52.           
  53.     } else if (0 == strcmp(argv[1], "voxel")){  
  54.         run_voxel(argc, argv);  
  55.     } else if (0 == strcmp(argv[1], "super")){  
  56.         run_super(argc, argv);  
  57.     } else if (0 == strcmp(argv[1], "detector")){  
  58.         run_detector(argc, argv);  
  59.     } else if (0 == strcmp(argv[1], "cifar")){  
  60.         run_cifar(argc, argv);  
  61.     } else if (0 == strcmp(argv[1], "go")){  
  62.         run_go(argc, argv);  
  63.     } else if (0 == strcmp(argv[1], "rnn")){  
  64.         run_char_rnn(argc, argv);  
  65.     } else if (0 == strcmp(argv[1], "vid")){  
  66.         run_vid_rnn(argc, argv);  
  67.     } else if (0 == strcmp(argv[1], "coco")){  
  68.         run_coco(argc, argv);  
  69.     } else if (0 == strcmp(argv[1], "classifier")){  
  70.         run_classifier(argc, argv);  
  71.     } else if (0 == strcmp(argv[1], "art")){  
  72.         run_art(argc, argv);  
  73.     } else if (0 == strcmp(argv[1], "tag")){  
  74.         run_tag(argc, argv);  
  75.     } else if (0 == strcmp(argv[1], "compare")){  
  76.         run_compare(argc, argv);  
  77.     } else if (0 == strcmp(argv[1], "dice")){  
  78.         run_dice(argc, argv);  
  79.     } else if (0 == strcmp(argv[1], "writing")){  
  80.         run_writing(argc, argv);  
  81.     } else if (0 == strcmp(argv[1], "3d")){  
  82.         composite_3d(argv[2], argv[3], argv[4], (argc > 5) ? atof(argv[5]) : 0);  
  83.     } else if (0 == strcmp(argv[1], "test")){  
  84.         test_resize(argv[2]);  
  85.     } else if (0 == strcmp(argv[1], "captcha")){  
  86.         run_captcha(argc, argv);  
  87.     } else if (0 == strcmp(argv[1], "nightmare")){  
  88.         run_nightmare(argc, argv);  
  89.     } else if (0 == strcmp(argv[1], "change")){  
  90.         change_rate(argv[2], atof(argv[3]), (argc > 4) ? atof(argv[4]) : 0);  
  91.     } else if (0 == strcmp(argv[1], "rgbgr")){  
  92.         rgbgr_net(argv[2], argv[3], argv[4]);  
  93.     } else if (0 == strcmp(argv[1], "reset")){  
  94.         reset_normalize_net(argv[2], argv[3], argv[4]);  
  95.     } else if (0 == strcmp(argv[1], "denormalize")){  
  96.         denormalize_net(argv[2], argv[3], argv[4]);  
  97.     } else if (0 == strcmp(argv[1], "statistics")){  
  98.         statistics_net(argv[2], argv[3]);  
  99.     } else if (0 == strcmp(argv[1], "normalize")){  
  100.         normalize_net(argv[2], argv[3], argv[4]);  
  101.     } else if (0 == strcmp(argv[1], "rescale")){  
  102.         rescale_net(argv[2], argv[3], argv[4]);  
  103.     } else if (0 == strcmp(argv[1], "ops")){  
  104.         operations(argv[2]);  
  105.     } else if (0 == strcmp(argv[1], "speed")){  
  106.         speed(argv[2], (argc > 3) ? atoi(argv[3]) : 0);  
  107.     } else if (0 == strcmp(argv[1], "partial")){  
  108.         partial(argv[2], argv[3], argv[4], atoi(argv[5]));  
  109.     } else if (0 == strcmp(argv[1], "average")){  
  110.         average(argc, argv);  
  111.     } else if (0 == strcmp(argv[1], "visualize")){  
  112.         visualize(argv[2], (argc > 3) ? argv[3] : 0);  
  113.     } else if (0 == strcmp(argv[1], "imtest")){  
  114.         test_resize(argv[2]);  
  115.     } else {  
  116.         fprintf(stderr, "Not an option: %s\n", argv[1]);  
  117.     }  
  118.     return 0;  
  119. }  
  120.   
  121. ////////////////////////////////////////////////////////////////  
  122. 在yolo.c中寻找到run_yolo函数  
  123. void train_yolo(char *cfgfile, char *weightfile)  
  124. void print_yolo_detections(FILE **fps, char *id, box *boxes, float **probs, int total, int classes, int w, int h)  
  125. void validate_yolo(char *cfgfile, char *weightfile)  
  126. void validate_yolo_recall(char *cfgfile, char *weightfile)  
  127. void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh)  
  128.   
  129. void run_yolo(int argc, char **argv)  
  130. {  
  131.     char *prefix = find_char_arg(argc, argv, "-prefix", 0);  
  132.     float thresh = find_float_arg(argc, argv, "-thresh", .2);  
  133.     int cam_index = find_int_arg(argc, argv, "-c", 0);  
  134.     int frame_skip = find_int_arg(argc, argv, "-s", 0);  
  135.     //提取输入参数,4个,格式如下所示  
  136.     if(argc < 4){  
  137.         fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);  
  138.         return;  
  139.     }  
  140.   
  141.     char *cfg = argv[3];  
  142.     char *weights = (argc > 4) ? argv[4] : 0;  
  143.     char *filename = (argc > 5) ? argv[5]: 0;  
  144.     if(0==strcmp(argv[2], "test")) test_yolo(cfg, weights, filename, thresh);  
  145.     //yolo测试图片  
  146.     else if(0==strcmp(argv[2], "train")) train_yolo(cfg, weights);  
  147.     else if(0==strcmp(argv[2], "valid")) validate_yolo(cfg, weights);  
  148.     else if(0==strcmp(argv[2], "recall")) validate_yolo_recall(cfg, weights);  
  149.     else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, voc_names, 20, frame_skip, prefix);  
  150.     //yolo测试webcam demo  
  151. }  
  152.   
  153. void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh)  
  154. {  
  155.     image *alphabet = load_alphabet();  
  156.     network net = parse_network_cfg(cfgfile);  
  157.     if(weightfile){  
  158.         load_weights(&net, weightfile);  
  159.     }  
  160. //加载网络权重  
  161.   
  162.     detection_layer l = net.layers[net.n-1];  
  163.     set_batch_network(&net, 1);  
  164. //设置网络  
  165.     srand(2222222);  
  166.     clock_t time;  
  167.     char buff[256];  
  168.     char *input = buff;  
  169.     int j;  
  170.     float nms=.4;  
  171.     box *boxes = calloc(l.side*l.side*l.n, sizeof(box));  
  172.     float **probs = calloc(l.side*l.side*l.n, sizeof(float *));  
  173.     for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = calloc(l.classes, sizeof(float *));  
  174.     while(1){  
  175.         if(filename){  
  176.             strncpy(input, filename, 256);  
  177.         } else {  
  178.             printf("Enter Image Path: ");  
  179.             fflush(stdout);  
  180.             input = fgets(input, 256, stdin);  
  181.             if(!input) return;  
  182.             strtok(input, "\n");  
  183.         }  
  184.         image im = load_image_color(input,0,0);  
  185.         image sized = resize_image(im, net.w, net.h);  
  186.         float *X = sized.data;  
  187.         time=clock();//计时  
  188.         network_predict(net, X);//预测  
  189.         printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));  
  190.         get_detection_boxes(l, 1, 1, thresh, probs, boxes, 0);  
  191.         if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms);  
  192.          
  193.         draw_detections(im, l.side*l.side*l.n, thresh, boxes, probs, voc_names, alphabet, 20);//在原图像中画出边界框  
  194.         save_image(im, "predictions");//保存显示图像  
  195.         show_image(im, "predictions");  
  196.         free_image(im);  
  197.         free_image(sized);  
  198. #ifdef OPENCV  
  199.         cvWaitKey(0);  
  200.         cvDestroyAllWindows();  
  201. #endif  
  202.         if (filename) break;  
  203.     }  
  204. }  
  205. </span>  
c文件虽然多,但是代码井井有条,见名思义。

猜你喜欢

转载自blog.csdn.net/baobei0112/article/details/80075128