Randomly generate a training set and a test set from the image folders

Randomly generate a training set and a test set. The script must be run from
the parent directory of the image folders. It assumes that each image in those folders has an annotation file in XML or TXT format.

# -*- coding: UTF-8 -*-
# import xml.etree.ElementTree as ET
import os
import random


# A draw of random.randint(1, 100) strictly below this value sends the image
# to the training list; every other draw sends it to the test list.
TRAIN_THRESHOLD = 75


def collect_images(root):
    """Return the paths of the JPEG files found one level below *root*.

    Only the immediate sub-directories of *root* are scanned; files lying
    directly in *root* and files in deeper levels are ignored.  A file is
    kept when its final extension is ``.jpg`` in any letter case.

    Args:
        root: Directory whose sub-directories hold the images.

    Returns:
        A list of image paths with every backslash replaced by a forward
        slash, in the order ``os.listdir`` reported them.
    """
    images = []
    for entry in os.listdir(root):
        sub_dir = os.path.join(root, entry)
        if not os.path.isdir(sub_dir):
            continue
        for name in os.listdir(sub_dir):
            path = os.path.join(sub_dir, name)
            if not os.path.isfile(path):
                continue
            # splitext inspects the LAST dot, so "img.v2.jpg" is recognised
            # and a name without any extension simply yields "".
            if os.path.splitext(name)[1].lower() == '.jpg':
                images.append(path.replace('\\', '/'))
    return images


def split_dataset(images, train_path, test_path, threshold=TRAIN_THRESHOLD):
    """Randomly distribute *images* over a training list and a test list.

    For every image one integer is drawn from ``random.randint(1, 100)``;
    the image path is written to *train_path* when the draw is below
    *threshold*, otherwise to *test_path*.  Both files are overwritten and
    receive one path per line.

    Args:
        images: Iterable of image path strings.
        train_path: File that receives the training image paths.
        test_path: File that receives the test image paths.
        threshold: Draws strictly below this value go to the training list.

    Returns:
        A ``(train_count, test_count)`` tuple.
    """
    train_count = 0
    test_count = 0
    # "with" guarantees both handles are closed even if a write fails.
    with open(train_path, 'w') as train_file, open(test_path, 'w') as test_file:
        for image in images:
            if random.randint(1, 100) < threshold:
                train_file.write(image + '\n')
                train_count += 1
            else:
                test_file.write(image + '\n')
                test_count += 1
    return train_count, test_count


def count_lines(path):
    """Return the number of lines in the text file at *path*.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
    """
    with open(path, 'r') as handle:
        return sum(1 for _ in handle)


def write_data_file(root, class_count):
    """Write ``<root>/cfg/taco.data`` and return its path.

    The file gets five lines: ``classes``, ``train``, ``valid``, ``names``
    and ``backup``, each path being built from *root* with forward slashes.

    Args:
        root: Project directory; it must already contain a ``cfg`` folder.
        class_count: Value written after ``classes=``.

    Returns:
        The path of the written file, using forward slashes.
    """
    base = root.replace('\\', '/')
    data_path = base + '/cfg/taco.data'
    with open(data_path, 'w') as f:
        f.write('classes=' + str(class_count) + '\n')
        f.write('train = ' + base + '/train_file.txt' + '\n')
        f.write('valid = ' + base + '/test_file.txt' + '\n')
        f.write('names = ' + base + '/cfg/taco.names' + '\n')
        f.write('backup = ' + base + '/backup' + '\n')
    return data_path


def main():
    """Build the train/test lists and ``cfg/taco.data`` for the current directory."""
    wd = os.getcwd()  # the script works relative to the current directory
    base = wd.replace('\\', '/')

    images = collect_images(wd)

    # NOTE: kept only for output compatibility.  The printed number is an
    # independent random draw; it is NOT the train/test split ratio, which is
    # fixed by TRAIN_THRESHOLD.
    probe = random.randint(1, 100)
    print("Probability: %d" % probe)

    split_dataset(images, base + '/train_file.txt', base + '/test_file.txt')

    # The class count is the number of lines in the names file.
    class_count = count_lines(base + '/cfg/taco.names')
    write_data_file(wd, class_count)

    print(base + '/cfg/taco.data')
    print(base + '/cfg/yolov4-tiny.cfg')
    print(base + '/pre_trained/yolov4-tiny.conv.29')
    print(base + '/backup')


if __name__ == "__main__":
    main()


Guess you like

Origin blog.csdn.net/ohhardtoname/article/details/115299636