PyTorch学習メモ(18) -- トレーニングセットとテストセットを分割するスクリプト

PyTorch学習メモ(18) – トレーニングセットとテストセットを分割するためのスクリプトファイル

    このブログ記事はPyTorchの学習メモ、第18回目となるコンテンツ記録で、主にトレーニングセットとテストセットを自動で分割する方法を記録しています。これには主に 2 つの方法があり、1 つは分類問題に適したデータセットをカテゴリごとに複数のフォルダーに保存することを目的とし、同じカテゴリの画像をトレーニング セットとテスト セットに分割することを目的としています。データは分類ごとに保存されるのではなく、同じフォルダーに直接配置され、データはトレーニング セットとテスト セットに分割されます。

1. カテゴリ別に保管する

    トレーニング セットとテスト セットを分割する場合、分割する必要があるフォルダーはflower_data/flower_photosで、その下にはデイジー、タンポポ、バラ、ヒマワリ、チューリップの5つの分類フォルダーがあり、分類用のスクリプトはsplit_dataです。 py では、スクリプト ファイル split_data.py とデータ フォルダー flowers_photos が並列関係にあり、下図に示すように、これらはすべてflower_data フォルダーの下に配置されます。
ここに画像の説明を挿入
    スクリプト ファイル Split_data.py のコードは次のとおりです。

# coding :UTF-8
# 文件功能: 代码实现自动将数据集划分为训练集和验证集的功能
# 开发人员: XXX
# 开发时间: 2021/12/3 6:07 下午
# 文件名称: split_data.py
# 开发工具: PyCharm

import os
from shutil import copy, rmtree
import random


def mk_file(file_path: str):
    if os.path.exists(file_path):
        # 如果文件夹存在,则先删除原文件夹再重新创建
        rmtree(file_path)
    os.makedirs(file_path)


def main():
    # 保证随机可复现
    random.seed(0)

    # 将数据集中10%的数据划分到验证集中
    split_rate = 0.1

    # 指向你解压后的flower_photos文件夹
    cwd = os.getcwd()  # 用于返回当前工作目录
    data_root = os.path.join(cwd, "flower_data")
    origin_flower_path = os.path.join(data_root, "flower_photos")
    assert os.path.exists(origin_flower_path), "path '{}' does not exist.".format(origin_flower_path)

    flower_class = [cla for cla in os.listdir(origin_flower_path)
                    if os.path.isdir(os.path.join(origin_flower_path, cla))]

    # 建立保存训练集的文件夹
    train_root = os.path.join(data_root, "train")
    mk_file(train_root)
    for cla in flower_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(train_root, cla))

    # 建立保存验证集的文件夹
    val_root = os.path.join(data_root, "val")
    mk_file(val_root)
    for cla in flower_class:
        # 建立每个类别对应的文件夹
        mk_file(os.path.join(val_root, cla))

    for cla in flower_class:
        cla_path = os.path.join(origin_flower_path, cla)
        images = os.listdir(cla_path)
        num = len(images)
        # 随机采样验证集的索引
        eval_index = random.sample(images, k=int(num * split_rate))
        for index, image in enumerate(images):
            if image in eval_index:
                # 将分配至验证集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(val_root, cla)
                copy(image_path, new_path)
            else:
                # 将分配至训练集中的文件复制到相应目录
                image_path = os.path.join(cla_path, image)
                new_path = os.path.join(train_root, cla)
                copy(image_path, new_path)
            print("\r[{}] processing [{}/{}]".format(cla, index + 1, num), end="")  # processing bar
        print()

    print("processing done!")


if __name__ == '__main__':
    main()

2. すべてが 1 つのフォルダーに保存されます

    すべての画像が 1 つのフォルダーに保存されている場合、このフォルダー内の画像はトレーニング セットとテスト セットに分割されます。Windows のバージョン コードは次のとおりです。

# -*- coding: utf-8 -*-
"""
Created on Tue Jul 20 16:28:13 2021

@author: NN
"""

import os
import random
import shutil

# 原始数据集路径
# origion_path = r'D:\蓝藻门'
origion_path = r'E:\BaiduNetdiskDownload\data_20211112_train_new'
names = os.listdir(origion_path)

# 保存路径
# save_train_dir = r'D:\藻类识别神经网络\分类网络\train'
# save_test_dir = r'D:\藻类识别神经网络\分类网络\test'

# 数据集类别及数量
for i in names:
    file_list = origion_path + '\\' + i
    image_list = os.listdir(file_list) # 获取图片的原始路径
    image_number = len(image_list)

    train_number = int(image_number * 0.75)
    train_sample = random.sample(image_list, train_number) # 从image_list中随机获取0.8比例的图像.
    test_sample = list(set(image_list) - set(train_sample))
    
    # 创建保存路径
    save_train_dir = r'D:\藻类数据\data_20211112\train' + '\\' + i
    save_test_dir = r'D:\藻类数据\data_20211112\test' + '\\' + i
    if not os.path.isdir(save_train_dir):
        os.makedirs(save_train_dir)
     
    if not os.path.isdir(save_test_dir):
        os.makedirs(save_test_dir)
        
    # 复制图像到目标文件夹
    for j in train_sample:
        shutil.copy(file_list + '\\' + j, save_train_dir)
    
    for k in test_sample:
        shutil.copy(file_list + '\\' + k, save_test_dir)

おすすめ

転載: blog.csdn.net/weixin_43981621/article/details/121706377