02_Entrenamiento del modelo PyTorch [generar conjunto de entrenamiento, conjunto de prueba, conjunto de verificación]

1. Código

import os
base_dir = "E:/pytorch_learning" #修改为当前Data 目录所在的绝对路径 
dataset_dir = os.path.join(base_dir, "Data", "cifar-10-png", "raw_test")
train_dir = os.path.join(base_dir, "Data", "train")
valid_dir = os.path.join(base_dir, "Data", "valid")
test_dir = os.path.join(base_dir, "Data", "test")
print(dataset_dir)
print(train_dir)
print(valid_dir)
print(test_dir)
import glob
import random
import shutil
#生成训练集、测试集、验证集的比例
train_per = 0.8
valid_per = 0.1
test_per = 0.1
def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)
for root, dirs, files in os.walk(dataset_dir):  #dirs为['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
        for sDir in dirs:
            #返回一个某一种文件夹下面的某一类型文件路径列表
            imgs_list = glob.glob(os.path.join(root, sDir, '*.png'))  #sDir文件夹中,png列表
            random.seed(666)
            random.shuffle(imgs_list)
            imgs_num = len(imgs_list)   #每个类别都是1000
            train_point = int(imgs_num * train_per)  # 800
            valid_point = int(imgs_num * (train_per + valid_per))  # 900

            for i in range(imgs_num):
                if i < train_point:
                    out_dir = os.path.join(train_dir, sDir) #在训练集创建 0-9 的类别文件夹
                elif i < valid_point:
                    out_dir = os.path.join(valid_dir, sDir)
                else:
                    out_dir = os.path.join(test_dir, sDir)

                makedir(out_dir)
                out_path = os.path.join(out_dir, os.path.split(imgs_list[i])[-1]) #名字+后缀
                shutil.copy(imgs_list[i], out_path)
                print('Class:{}, train:{}, valid:{}, test:{}'.format(sDir, train_point, valid_point-train_point, imgs_num-valid_point))
            

2. Efecto

 

 

Supongo que te gusta

Origin blog.csdn.net/zhang2362167998/article/details/128806715
Recomendado
Clasificación