Deep Eye PyTorch Check-in (6): Methods for Splitting a Dataset into Training, Validation, and Test Sets

Preface


   Deep learning is ultimately driven by data, so data matters a great deal. Data collected from the Internet is usually not already divided into training, validation, and test sets, so we have to split it ourselves. The code in this note is based on the reference code provided by the Deep Eye instructor, with some extensions.

   Code and data set: dataSplit.zip


Task


  The author collected two classes of images from the Internet, ants and bees, 100 images of each, and placed them in two folders (ants and bees) under old_data. The task is to split this dataset into a training set, a validation set, and a test set.


Data set split


  Splitting a dataset means randomly assigning the data to the training, validation, and test sets. There are two ways to do this. The first splits the data physically: the files are copied into separate directories on disk. The second generates separate path-list files for the training, validation, and test sets, and the split is applied when the data is read.

  • Overview of the functions used

  os.listdir(): returns the names of the files and folders directly under the given path; it does not descend into subdirectories. For example:

print(os.listdir(old_path))
# output
# ['ants', 'bees']
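A minimal, self-contained check that os.listdir() stops at one level. It builds a throwaway copy of the old_data layout in a temporary directory (the file name ants_0.jpg is made up) and sorts the results, because os.listdir() returns names in arbitrary order:

```python
import os
import tempfile

# Build a throwaway directory tree shaped like old_data (file names are made up)
tmp = tempfile.mkdtemp()
old_path = os.path.join(tmp, 'old_data')
for cls in ('ants', 'bees'):
    os.makedirs(os.path.join(old_path, cls))
    open(os.path.join(old_path, cls, cls + '_0.jpg'), 'w').close()

top_level = sorted(os.listdir(old_path))                          # only the immediate children
ants_files = sorted(os.listdir(os.path.join(old_path, 'ants')))  # going one level down needs a second call
print(top_level)   # ['ants', 'bees']
print(ants_files)  # ['ants_0.jpg']
```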

  os.walk(): takes four parameters (top, topdown, onerror, followlinks); only top, the path to traverse, is used here. It traverses every directory under the given path and yields a triple (root, dirs, files) for each directory, including top itself. root is the path of the directory currently being visited; dirs is a list of the names of its immediate subdirectories (not their subdirectories); files is a list of the names of the files in it. For example:

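# os.walk() returns a generator; unpacking it into three names only works here
# because old_data contains exactly three directories (itself, ants and bees)
top_triple, ants_triple, bees_triple = os.walk(old_path)
print(top_triple, '\n', ants_triple, '\n', bees_triple)
# output
# ('old_data', ['ants', 'bees'], []) 
# ('old_data\\ants', [], ['1030023514_aad5c608f9.jpg', '1095476100_3906d8afde.jpg', .....])
# ('old_data\\bees', [], ['1092977343_cb42b38d62.jpg', '1093831624_fb5fbe2308.jpg', .....])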

for root_dir, sub_dirs, files in os.walk(old_path):                     # iterate over each triple yielded by os.walk(), unpacking it into three variables
    print('root_dir:', root_dir, 'sub_dirs:', sub_dirs, 'file:', files)
    
# output
# root_dir: old_data sub_dirs: ['ants', 'bees'] file: []
# root_dir: old_data\ants sub_dirs: [] file: ['1030023514_aad5c608f9.jpg', .....]    
# root_dir: old_data\bees sub_dirs: [] file: ['1092977343_cb42b38d62.jpg', .....]
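Because os.walk() visits every level, it can collect all images in one pass. The sketch below, with a hypothetical helper images_by_class() and a made-up test tree, maps each class folder to its sorted .jpg paths. It assumes, as in old_data, that images sit exactly one level below the root and that class folder names are unique:

```python
import os
import tempfile


def images_by_class(root):
    """Walk root and map each class folder name to a sorted list of its .jpg paths."""
    result = {}
    for dir_path, _dirs, files in os.walk(root):
        jpgs = sorted(f for f in files if f.endswith('.jpg'))
        if jpgs:  # the top-level folder holds no images, so it is skipped
            result[os.path.basename(dir_path)] = [os.path.join(dir_path, f) for f in jpgs]
    return result


# Hypothetical test tree: ants has 2 images, bees has 1 image plus a .txt file
tmp = tempfile.mkdtemp()
old_path = os.path.join(tmp, 'old_data')
for cls, names in {'ants': ['a1.jpg', 'a2.jpg'], 'bees': ['b1.jpg', 'notes.txt']}.items():
    os.makedirs(os.path.join(old_path, cls))
    for name in names:
        open(os.path.join(old_path, cls, name), 'w').close()

found = images_by_class(old_path)
print({cls: len(paths) for cls, paths in sorted(found.items())})  # {'ants': 2, 'bees': 1}
```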

  os.path.join(): joins path components with the operating system's path separator, e.g. os.path.join('old_data', 'ants') gives old_data\ants on Windows and old_data/ants on Linux or macOS.
  random.shuffle(): shuffles a list in place and returns None, so its result should not be assigned to anything.
  filter(): filter(function, iterable) passes each element of iterable to function and keeps only the elements for which it returns True. The result is an iterator, which can be converted to a list with list(). The line below removes every file that does not end in .jpg from the img_names list and converts the result to a list.

list(filter(lambda x: x.endswith('.jpg'), img_names))
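These three functions are usually used together. The sketch below (the helper shuffled_jpg_paths and its seed parameter are hypothetical) filters out non-.jpg files, shuffles the rest, and joins each name onto its class folder. Two small changes from the post are deliberate assumptions: the suffix check is case-insensitive so that .JPG files are kept, and the names are sorted before a seeded shuffle so the result does not depend on the order os.listdir() happens to return:

```python
import os
import random


def shuffled_jpg_paths(class_dir, file_names, seed=None):
    """Keep only .jpg names, shuffle them reproducibly, and join each onto class_dir."""
    jpgs = list(filter(lambda x: x.lower().endswith('.jpg'), file_names))  # also keeps '.JPG'
    jpgs.sort()                                  # os.listdir() order is arbitrary, so fix it first
    random.Random(seed).shuffle(jpgs)            # shuffles in place and returns None
    return [os.path.join(class_dir, name) for name in jpgs]


names = ['b.jpg', 'a.JPG', 'readme.txt', 'c.jpg']
paths = shuffled_jpg_paths('ants', names, seed=42)
print(len(paths))  # 3
```

Passing the same seed on every run gives the same order, which makes experiments easier to compare.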
  • Method 1: split the data directly

  This method uses the shutil module to copy the data. The idea is to traverse the directory tree to find the folder of every class. For each class, get the names of all its images and shuffle them randomly. Then create folders for the three datasets, walk through the shuffled name list, and assign each image to the training, validation, or test set in an 8:1:1 ratio (the first 80% of the shuffled list goes to training, the next 10% to validation, and the rest to testing). Finally, build the source and destination paths and copy each image to its destination. The code is as follows:

import os
import random
import math
import shutil

def data_split(old_path):
    new_path = 'data'
    os.makedirs(new_path, exist_ok=True)                                             # create the output root if it does not exist yet
    for root_dir, sub_dirs, _ in os.walk(old_path):                                  # iterate over each triple yielded by os.walk(), unpacking it into three variables
        for sub_dir in sub_dirs:
            file_names = os.listdir(os.path.join(root_dir, sub_dir))                 # list the files in each class subdirectory
            file_names = list(filter(lambda x: x.endswith('.jpg'), file_names))      # drop files that are not .jpg

            random.shuffle(file_names)
            n = len(file_names)
            for i in range(n):
                if i < math.floor(0.8 * n):
                    sub_path = os.path.join(new_path, 'train_set', sub_dir)
                elif i < math.floor(0.9 * n):
                    sub_path = os.path.join(new_path, 'val_set', sub_dir)
                else:                                                                # the remaining ~10% go to the test set
                    sub_path = os.path.join(new_path, 'test_set', sub_dir)
                os.makedirs(sub_path, exist_ok=True)

                shutil.copy(os.path.join(root_dir, sub_dir, file_names[i]), os.path.join(sub_path, file_names[i]))   # copy the image from source to destination

if __name__ == '__main__':
    data_path = 'old_data'
    data_split(data_path)

  [Figure: result]
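The cut points used in data_split() can be checked in isolation. Below is a sketch that separates the split logic from the file copying; assign_splits and its ratios and seed parameters are hypothetical, not part of the original code. It computes the cut points with integer arithmetic, which for the 8:1:1 ratio should give the same boundaries as math.floor(0.8 * n) and math.floor(0.9 * n):

```python
import random


def assign_splits(file_names, ratios=(8, 1, 1), seed=None):
    """Shuffle a copy of file_names and cut it into train/val/test lists.

    ratios are integer weights (8:1:1 as in the post); seed is a
    hypothetical extra parameter that makes the shuffle reproducible.
    """
    names = list(file_names)                 # copy, so the caller's list is not shuffled in place
    random.Random(seed).shuffle(names)       # private RNG: leaves the global random state alone
    n = len(names)
    total = sum(ratios)
    # integer arithmetic avoids floating-point rounding in the cut points
    train_end = n * ratios[0] // total
    val_end = n * (ratios[0] + ratios[1]) // total
    return {
        'train_set': names[:train_end],
        'val_set': names[train_end:val_end],
        'test_set': names[val_end:],         # everything left over goes to the test set
    }


names = ['{:03d}.jpg'.format(i) for i in range(100)]
splits = assign_splits(names, seed=0)
print({k: len(v) for k, v in splits.items()})  # {'train_set': 80, 'val_set': 10, 'test_set': 10}
```

With 100 images per class this gives 80/10/10; with only 7 images it gives 5/1/1, so very small classes end up with very few validation and test images.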
  • Method 2: generate path-list files

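  The code is similar to the first method, except that instead of copying images it writes their paths to three text files (train_set.txt, val_set.txt and test_set.txt). Each line has the form label,path, where label is the index of the class folder. Because the path lists are generated once in advance, they do not have to be rebuilt every time the data is loaded, which simplifies the data reading process. The code is as follows: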

import os
import random
import math

def data_list(path):

    for root_dir, sub_dirs, _ in os.walk(path):                                  # iterate over each triple yielded by os.walk(), unpacking it into three variables
        idx = 0
        for sub_dir in sorted(sub_dirs):                                         # sort so each class gets the same label on every run
            file_names = os.listdir(os.path.join(root_dir, sub_dir))             # list the files in each class subdirectory
            file_names = list(filter(lambda x: x.endswith('.jpg'), file_names))  # drop files that are not .jpg

            random.shuffle(file_names)
            n = len(file_names)
            for i in range(n):
                if i < math.floor(0.8 * n):
                    txt_name = 'train_set.txt'
                elif i < math.floor(0.9 * n):
                    txt_name = 'val_set.txt'
                else:                                                            # the remaining ~10% go to the test set
                    txt_name = 'test_set.txt'
                # mode='a' appends, so delete the old .txt files before running this again
                with open(os.path.join(path, txt_name), mode='a') as file:
                    file.write(str(idx) + ',' + os.path.join(path, sub_dir, file_names[i]) + '\n')     # changed for later use: ',' instead of ' ' as the separator, and sub_dir added to the path
            idx += 1


if __name__ == '__main__':
    data_path = 'old_data'
    data_list(data_path)

  [Figure: result]
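Since the point of the list files is to simplify data loading, here is a sketch of the reading side: a small parser for the label,path lines written by data_list(). The function name read_path_list and the sample file are hypothetical. One option, not shown here, is to call it from the __init__ of a PyTorch Dataset and open the image at each path in __getitem__.

```python
import os
import tempfile


def read_path_list(txt_path):
    """Parse a 'label,path' list file into (int label, path) tuples."""
    samples = []
    with open(txt_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue  # skip blank lines
            label, path = line.split(',', 1)  # split once, so a comma inside the path survives
            samples.append((int(label), path))
    return samples


# Hypothetical list file in the format written by data_list()
tmp = tempfile.mkdtemp()
txt = os.path.join(tmp, 'train_set.txt')
with open(txt, 'w') as f:
    f.write('0,old_data/ants/a1.jpg\n1,old_data/bees/b1.jpg\n')

samples = read_path_list(txt)
print(samples)  # [(0, 'old_data/ants/a1.jpg'), (1, 'old_data/bees/b1.jpg')]
```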


References


  https://ai.deepshare.net/detail/p_5df0ad9a09d37_qYqVmt85/6
  https://www.runoob.com/python/os-walk.html
  https://www.runoob.com/python/os-listdir.html


Source: blog.csdn.net/sinat_35907936/article/details/105611737