Split the dataset into a training set (80%) and a test set (20%).
Because the class folders contain many files with the same name, pictures would overwrite each other when the training set and test set are generated. To avoid this, the code is adjusted as follows:
import codecs
import os
import random
import shutil
from PIL import Image
# Fraction of the images assigned to the training set (80%).
train_ratio = 0.8
# Root directory: the "xh" folder contains the class sub-folders to process.
all_file_dir = './xh'
# Class names are the sub-directories of the root, skipping hidden entries and
# anything ending in "Set" (the generated trainImageSet / evalImageSet output
# folders). Sorting keeps the order stable from run to run.
class_list = sorted(
    name
    for name in os.listdir(all_file_dir)
    if not name.startswith('.')
    and not name.endswith('Set')
    and os.path.isdir(os.path.join(all_file_dir, name))
)
print(class_list)
# Output folders for the two subsets, both placed under all_file_dir.
train_image_dir = os.path.join(all_file_dir, "trainImageSet")
eval_image_dir = os.path.join(all_file_dir, "evalImageSet")
# Create each output folder only if it is not already there.
for output_dir in (train_image_dir, eval_image_dir):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
# Split every class folder into a training and an evaluation subset.
#
# Writes three listing files under all_file_dir:
#   train.txt      - "<relative image path> <label id>" for training images
#   eval.txt       - "<relative image path> <label id>" for evaluation images
#   label_list.txt - "<label id>\t<class name>" for every class
# All three are opened in write mode ('w' truncates previous content); the
# with-statement closes them even if the split fails part-way through.
with codecs.open(os.path.join(all_file_dir, "train.txt"), 'w') as train_file, \
        codecs.open(os.path.join(all_file_dir, "eval.txt"), 'w') as eval_file, \
        codecs.open(os.path.join(all_file_dir, "label_list.txt"), "w") as label_list:
    for label_id, class_dir in enumerate(class_list):
        label_list.write("{0}\t{1}\n".format(label_id, class_dir))
        image_path_pre = os.path.join(all_file_dir, class_dir)
        for file_name in os.listdir(image_path_pre):
            src_path = os.path.join(image_path_pre, file_name)
            # Different class folders contain files with the same name, so the
            # class name is prepended to keep copies from overwriting each other.
            unique_name = class_dir + file_name
            try:
                # Opening the file with PIL filters out entries that are not
                # readable images; the handle is released immediately.
                with Image.open(src_path):
                    pass
                if random.uniform(0, 1) <= train_ratio:
                    subset_name = "trainImageSet"
                    subset_dir = train_image_dir
                    listing = train_file
                else:
                    subset_name = "evalImageSet"
                    subset_dir = eval_image_dir
                    listing = eval_file
                # Both subsets copy from the original path and store the file
                # under the prefixed name, so the listing entry always matches
                # the file that exists on disk.
                shutil.copyfile(src_path, os.path.join(subset_dir, unique_name))
                listing.write("{0} {1}\n".format(
                    os.path.join(subset_name, unique_name), label_id))
            except Exception as e:
                # Some files cannot be opened or copied; skip them (best-effort
                # cleaning) but report which ones instead of hiding the failure.
                print("Skipping {0}: {1}".format(src_path, e))
Obtained file:
Code reference: Classification of surface defects of hot-rolled steel strip based on PaddleClas
https://aistudio.baidu.com/aistudio/projectdetail/685319