Reading the darknet source: train_detector

Running a training command like one of the following on the darknet framework inevitably lands in the train_detector function, which is the entry point for training an object detector.

        ./darknet detector train cfg/coco.data cfg/yolov2.cfg darknet19_448.conv.23

        ./darknet detector train cfg/coco.data cfg/yolov2.cfg darknet19_448.conv.23 -gpus 0,1,2,3,4

// If -gpus is not given, gpus defaults to {0} and ngpus to 1.
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
{
    list *options = read_data_cfg(datacfg);
    char *train_images = option_find_str(options, "train", "data/train.list");
	// directory where weight snapshots are saved
    char *backup_directory = option_find_str(options, "backup", "/backup/");

    srand(time(0));

	// extract the base name from the cfg path, e.g. "yolov2" from /a/b/yolov2.cfg
    char *base = basecfg(cfgfile); // used to name the saved weight files
    printf("%s\n", base);
	
    float avg_loss = -1;
    network **nets = calloc(ngpus, sizeof(network));

    srand(time(0));
    int seed = rand();
    int i;
    for(i = 0; i < ngpus; ++i){
        srand(seed);
#ifdef GPU
        cuda_set_device(gpus[i]);
#endif
		// create one network per GPU
        nets[i] = load_network(cfgfile, weightfile, clear);
        nets[i]->learning_rate *= ngpus; // scale the learning rate by the number of GPUs
    }
    srand(time(0));
    network *net = nets[0];

	// Why multiply by subdivisions instead of dividing? parse_net_options() already
	// divided the cfg batch by subdivisions, so net->batch * net->subdivisions
	// recovers the full batch size written in the cfg.
    int imgs = net->batch * net->subdivisions * ngpus;
    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net->learning_rate, net->momentum, net->decay);

	data train, buffer;

    //the last layer e.g. [region] for yolov2
    layer l = net->layers[net->n - 1];

    int classes = l.classes; 
    float jitter = l.jitter;

    list *plist = get_paths(train_images);
    //int N = plist->size;
    char **paths = (char **)list_to_array(plist);

    load_args args = get_base_args(net);
    args.coords = l.coords;
    args.paths = paths;
    args.n = imgs;        // number of images to load per iteration
    args.m = plist->size; // total number of training images
    args.classes = classes;
    args.jitter = jitter;
    args.num_boxes = l.max_boxes;
    args.d = &buffer;
    args.type = DETECTION_DATA;
    //args.type = INSTANCE_DATA;
    args.threads = 64;

	// args.n images and their ground-truth boxes are loaded into buffer.X and buffer.y
    pthread_t load_thread = load_data(args); 

	double time;
    int count = 0;
    //while(i*imgs < N*120){
    while(get_current_batch(net) < net->max_batches){
		// l.random enables multi-scale training; if set, the network is resized every 10 batches
        if(l.random && count++%10 == 0){
            printf("Resizing\n");
			// randomly pick a dimension from {320, 352, ..., 608}
            int dim = (rand() % 10 + 10) * 32;

			// for the last 200 batches, fix the input size at 608
            if (get_current_batch(net)+200 > net->max_batches) dim = 608;
            //int dim = (rand() % 4 + 16) * 32;
            printf("%d\n", dim);
            args.w = dim;
            args.h = dim;

            pthread_join(load_thread, 0); // wait for the loader thread to finish
            train = buffer;
            free_data(train); // discard the batch loaded at the old size
            load_thread = load_data(args);

            #pragma omp parallel for
            for(i = 0; i < ngpus; ++i){
				// resize every network to the new input dimensions
                resize_network(nets[i], dim, dim);
            }
            net = nets[0];
        }
		
        time=what_time_is_it_now();
		// block until args.threads worker threads have loaded args.n images; the loader thread then exits
        pthread_join(load_thread, 0); 
		// the args.n loaded images now sit in args.d, i.e. in buffer
        train = buffer;

		// immediately kick off loading of the next batch in the background (double buffering)
        load_thread = load_data(args);

        printf("Loaded: %lf seconds\n", what_time_is_it_now()-time);

        time=what_time_is_it_now();
        float loss = 0;
#ifdef GPU
        if(ngpus == 1){
            loss = train_network(net, train);
        } else {
            loss = train_networks(nets, ngpus, train, 4);
        }
#else
        loss = train_network(net, train);
#endif
        if (avg_loss < 0) avg_loss = loss;
        avg_loss = avg_loss*.9 + loss*.1; // exponential moving average of the loss

        i = get_current_batch(net);
        printf("%ld: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), what_time_is_it_now()-time, i*imgs);
        if(i%100==0){
#ifdef GPU
            if(ngpus != 1) sync_nets(nets, ngpus, 0);
#endif
            char buff[256];
            sprintf(buff, "%s/%s.backup", backup_directory, base);
            save_weights(net, buff);
        }
        if(i%10000==0 || (i < 1000 && i%100 == 0)){
#ifdef GPU
            if(ngpus != 1) sync_nets(nets, ngpus, 0);
#endif
            char buff[256];
            sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
            save_weights(net, buff);
        }
		// Important: train points at memory that came from buffer, which was
		// dynamically allocated inside load_data. The next load_data call will
        // allocate again, so the previous allocation must be freed here, or the
        // pointer is orphaned and a memory leak is unavoidable.
        free_data(train);
    }
#ifdef GPU
    if(ngpus != 1) sync_nets(nets, ngpus, 0);
#endif
    char buff[256];
    sprintf(buff, "%s/%s_final.weights", backup_directory, base);
    save_weights(net, buff);
}
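As a quick check on the multi-scale comments in the listing above, this small standalone snippet enumerates every dimension that (rand() % 10 + 10) * 32 can produce:

#include <stdio.h>

/* (rand() % 10 + 10) * 32 maps the residues 0..9 onto the ten
   multi-scale training sizes 320, 352, ..., 608. */
int main(void)
{
    int r;
    for (r = 0; r < 10; ++r)
        printf("%d ", (r + 10) * 32);
    printf("\n"); /* prints: 320 352 384 416 448 480 512 544 576 608 */
    return 0;
}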

  Each pass through the while loop is one training iteration, and the amount of data consumed per iteration is imgs, which equals net->batch * net->subdivisions * ngpus. Considering only the CPU case, or the single-GPU case, ngpus is 1. As for subdivisions, in most cfg files I have looked at (e.g., yolov2.cfg) the uncommented default is 1; those are the testing settings, with the training settings (batch=64, subdivisions=8) usually commented out alongside them. If subdivisions is not 1, each while iteration loads net->batch * net->subdivisions images (from here on I only consider the CPU case, so ngpus is 1); otherwise it loads exactly one net->batch (both values are defined explicitly in the cfg file). Keep in mind that the parser already divided the cfg batch by subdivisions, so net->batch * net->subdivisions is simply the batch value written in the cfg. Once the data is ready, training proceeds into the train_network function.
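Before looking at train_network, a quick worked example makes the bookkeeping concrete. The numbers below assume the yolov2.cfg training settings (cfg batch=64, subdivisions=8) and a single GPU; they are illustrative values, not something read from the code above:

#include <stdio.h>

int main(void)
{
    /* Assumed values: the yolov2.cfg *training* settings. */
    int cfg_batch = 64, subdivisions = 8, ngpus = 1;

    /* parse_net_options() divides the cfg batch by subdivisions,
       so net->batch is the per-forward-pass mini-batch. */
    int net_batch = cfg_batch / subdivisions;    /* 8  */

    /* Images loaded per while-loop iteration in train_detector(). */
    int imgs = net_batch * subdivisions * ngpus; /* 64 */

    /* train_network() splits those rows into d.X.rows / net->batch
       mini-batches: one forward/backward pass per subdivision. */
    int n = imgs / net_batch;                    /* 8  */

    printf("net->batch=%d  imgs=%d  n=%d\n", net_batch, imgs, n);
    return 0;
}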

float train_network(network *net, data d)
{
    assert(d.X.rows % net->batch == 0);
    int batch = net->batch;
    int n = d.X.rows / batch;

    int i;
    float sum = 0;
    for(i = 0; i < n; ++i){
		// d.X.rows is the imgs computed by the caller:
		// net->batch * net->subdivisions * ngpus.
		// Note that this batch (net->batch) is already the cfg batch divided
		// by subdivisions, so each pass here handles one subdivision.
        get_next_batch(d, batch, i*batch, net->input, net->truth);
        float err = train_network_datum(net);
        sum += err;
    }

	// return the average loss per image
    return (float)sum/(n*batch);
}

First, the variable n: it equals d.X.rows / batch, and d.X.rows is exactly the imgs worth of data prepared in one while iteration above. With ngpus = 1, n works out to be subdivisions. The loop then simply takes one batch of data at a time and trains n times.
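Here is a minimal sketch of that slicing pattern, using a plain float array as a stand-in for darknet's data/matrix types; next_batch is a hypothetical simplification of what get_next_batch does (copying a contiguous block of rows into the network's input buffer):

#include <stdio.h>

#define ROWS  8   /* stands in for d.X.rows (imgs) */
#define BATCH 2   /* stands in for net->batch      */

/* Hypothetical stand-in for get_next_batch(): copy BATCH rows
   starting at row `offset` out of the full data block. */
static void next_batch(const float *X, int offset, float *input)
{
    int r;
    for (r = 0; r < BATCH; ++r)
        input[r] = X[offset + r];
}

int main(void)
{
    float X[ROWS] = {0, 1, 2, 3, 4, 5, 6, 7}; /* one value per "row" */
    float input[BATCH];
    int n = ROWS / BATCH; /* number of mini-batches; equals subdivisions */
    int i;

    for (i = 0; i < n; ++i) {
        next_batch(X, i * BATCH, input);
        printf("mini-batch %d: %.0f %.0f\n", i, input[0], input[1]);
    }
    return 0;
}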

float train_network_datum(network *net)
{
    *net->seen += net->batch; // update the running count of images trained on
    net->train = 1;
    forward_network(net);
    backward_network(net);
    float error = *net->cost;
    // only update the weights once every subdivisions mini-batches
    // (gradients are accumulated across the subdivisions in between)
    if(((*net->seen)/net->batch)%net->subdivisions == 0) update_network(net);
    return error;
}

The intent is fairly clear by now: each call trains on one mini-batch of data, with a forward pass (forward_network) and a backward pass (backward_network), and the weights are updated (update_network) once every subdivisions mini-batches.
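To see that update cadence in isolation, here is a small standalone simulation; batch and subdivisions are made-up small values, not taken from any cfg:

#include <stdio.h>

int main(void)
{
    int batch = 8, subdivisions = 4; /* hypothetical values */
    int seen = 0;
    int i;

    /* Mirror the condition in train_network_datum(): the weights are
       updated once every `subdivisions` mini-batches. */
    for (i = 1; i <= 12; ++i) {
        seen += batch;
        printf("mini-batch %2d: seen=%3d update=%s\n",
               i, seen, ((seen / batch) % subdivisions == 0) ? "yes" : "no");
    }
    return 0;
}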

Reposted from blog.csdn.net/ChuiGeDaQiQiu/article/details/81276844