MTCNN-Caffe(二):按比例混合生成训练集、验证集的 list

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/xiaomifanhxx/article/details/86583596

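上篇博客讲述了在训练 Caffe 模型之前生成的 3 个 txt 样本列表文件。把它们分别划分为训练集和验证集之后,还需要按比例混合各类样本,生成整体的 train.txt / val.txt;下面的 py 文件就用来完成这一步(它只有一个输出文件参数,因此训练集和验证集各运行一次)。脚本需要传入 4 个列表(pos、neg、part、landmark),采样比例通过命令行参数 sample_percents 传入,本文使用的比例 pos:neg:part:landmark 为 1:3:1:0;landmark 比例为 0 时,其列表文件仍需作为参数提供(可以为空)。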

#!/usr/bin/env python
"""
classify.py is an out-of-the-box image classifer callable from the command line.

By default it configures and runs the Caffe reference ImageNet model.
"""
import os
import sys
import argparse
import glob
import time
import random

def view_bar(num, total):
    """Redraw a 100-column text progress bar on the current terminal line.

    Args:
        num: zero-based index of the item currently being processed. That
            item is counted as done, so ``num == total - 1`` shows 100%.
        total: total number of items. A non-positive total is drawn as a
            full bar instead of raising ZeroDivisionError.
    """
    if total <= 0:
        rate_num = 100
    else:
        # Integer arithmetic: the float form int(num / total * 100) truncates
        # e.g. 29/100 to 28, and "+ 1" on it never reaches 100% for small
        # totals such as 3. Clamp so out-of-range indices cannot overflow
        # the 100-column bar.
        rate_num = max(0, min(100, (num + 1) * 100 // total))
    bar = "#" * rate_num + " " * (100 - rate_num)
    sys.stdout.write('\r[%s]%d%%' % (bar, rate_num))
    sys.stdout.flush()

def _read_sample_list(path):
    """Return the stripped, non-blank lines of the sample list file *path*.

    Blank lines (e.g. a trailing empty line) are dropped so they are never
    copied into the mixed output as empty entries.
    """
    with open(path) as f:
        return [line.strip() for line in f if line.strip()]


def _endless_shuffled(items):
    """Yield the elements of *items* forever, reshuffling before each pass.

    *items* is shuffled in place. It must be non-empty: an empty list would
    make this loop forever without yielding, so callers check that first.
    """
    while True:
        random.shuffle(items)
        for item in items:
            yield item


def main(argv):
    """Mix the pos/neg/part/landmark sample lists into one list file.

    The ratio ``pos:neg:part:landmark`` comes from the ``sample_percents``
    argument. The shuffled positives are walked once, in groups of ``pos``
    lines. After each group, ``neg`` negative, ``part`` part and
    ``landmark`` landmark lines are written. The three non-positive lists
    are cycled, reshuffled after each full pass, for as long as needed.

    Args:
        argv: full command line including the program name (e.g. sys.argv);
            None means "use sys.argv".

    Exits with status 2 (via argparse) on an invalid ratio, or when a list
    with a non-zero ratio is empty.
    """
    parser = argparse.ArgumentParser(
        description="Mix pos/neg/part/landmark sample lists into one list file."
    )
    # Required arguments: input and output files.
    parser.add_argument("pos_file", type=str, help="positive sample list")
    parser.add_argument("neg_file", type=str, help="negative sample list")
    parser.add_argument("part_file", type=str, help="partial sample list")
    parser.add_argument("landmark_file", type=str, help="landmark sample list")
    # Positional, hence always required: argparse would never apply a
    # default here, so none is declared.
    parser.add_argument(
        "sample_percents",
        type=str,
        help="sampling ratio pos:neg:part:landmark, e.g. 1:3:1:0",
    )
    parser.add_argument("output_file", type=str, help="output list")

    # argv[0] is the program name; parse_args(None) falls back to sys.argv.
    args = parser.parse_args(None if argv is None else argv[1:])

    try:
        sample_percents = [int(s) for s in args.sample_percents.split(':')]
    except ValueError:
        parser.error(
            "sample percents must be integers separated by ':', e.g. 1:3:1:0"
        )
    if len(sample_percents) != 4:
        parser.error("sample percents must have 4 numbers")
    pos_ratio, neg_ratio, part_ratio, landmark_ratio = sample_percents
    if pos_ratio < 1 or min(sample_percents) < 0:
        parser.error(
            "sample percents must be non-negative and the pos ratio "
            "must be at least 1"
        )

    pos_list = _read_sample_list(args.pos_file)
    random.shuffle(pos_list)

    # Lists that are interleaved after each group of positives, in output
    # order. Every file is read even when its ratio is 0.
    fillers = [
        (args.neg_file, _read_sample_list(args.neg_file), neg_ratio),
        (args.part_file, _read_sample_list(args.part_file), part_ratio),
        (args.landmark_file, _read_sample_list(args.landmark_file),
         landmark_ratio),
    ]
    for path, lines, ratio in fillers:
        if ratio > 0 and not lines:
            parser.error(
                "%s is empty but its sample ratio is %d" % (path, ratio)
            )
    streams = [
        (_endless_shuffled(lines), ratio)
        for _, lines, ratio in fillers
        if ratio > 0
    ]

    total_num = len(pos_list)
    with open(args.output_file, 'w') as f1:
        for start in range(0, total_num, pos_ratio):
            end = min(start + pos_ratio, total_num)
            # Report the index of the last positive in this group.
            view_bar(end - 1, total_num)
            for pos in pos_list[start:end]:
                f1.write(pos + "\n")
            for stream, ratio in streams:
                for _ in range(ratio):
                    f1.write(next(stream) + "\n")
    # End the progress-bar line so later output starts on a fresh line.
    sys.stdout.write("\n")

if __name__ == "__main__":
    # Script entry point: hand the complete command line to main().
    main(argv=sys.argv)

猜你喜欢

转载自blog.csdn.net/xiaomifanhxx/article/details/86583596