A universal method of loading local data sets

1. Description

1.1 Data set placement format description

Before loading, sort the pictures by category and place each category in its own subfolder of the data set folder. The original post showed this layout in a figure (caption: "Data set storage format").
The example has only 2 categories, but any number of categories works; the loader does not depend on the category count.
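
The layout described above can be checked before training. The following is a minimal sketch, not part of the original program: `summarize_dataset` and the `img_*.jpg` file names are hypothetical, and the class names `OK` and `Yeah` are taken from the label printout later in this article. It builds a throwaway folder with one subfolder per category and counts the files in each.

```python
import pathlib
import tempfile


def summarize_dataset(data_dir):
    """Return {class_name: number_of_files} for each subfolder of data_dir."""
    root = pathlib.Path(data_dir)
    summary = {}
    for class_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        summary[class_dir.name] = sum(1 for f in class_dir.iterdir() if f.is_file())
    return summary


# Build a tiny throwaway layout: one subfolder per category (hypothetical file names).
with tempfile.TemporaryDirectory() as tmp:
    for class_name, n_files in (("OK", 2), ("Yeah", 3)):
        class_dir = pathlib.Path(tmp) / class_name
        class_dir.mkdir()
        for i in range(n_files):
            (class_dir / f"img_{i}.jpg").touch()
    layout = summarize_dataset(tmp)

print(layout)  # {'OK': 2, 'Yeah': 3}
```

Running such a check on a real data set folder makes it easy to spot an empty or misplaced category before the loader is called.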

1.2 Function reference description

To import the functions from this program into another program, do the following:

import sys      # add the folder by absolute path; otherwise importing load_data raises an error
# path of the folder that contains the load_data program
sys.path.append(r'E:\Pycharm\project\yeah&ok\load_data')
from load_data import load_data_func, test_image, augment

Generally, only load_data_func and test_image need to be imported.
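
If the import code runs more than once (for example in a notebook), the same folder is appended to `sys.path` each time. A small sketch of a guard against that, assuming a hypothetical helper name `add_to_sys_path` and a placeholder folder that you would replace with the real location of load_data.py:

```python
import sys
from pathlib import Path


def add_to_sys_path(module_dir):
    """Append module_dir to sys.path only if it is not already there."""
    module_dir = str(Path(module_dir))
    if module_dir not in sys.path:
        sys.path.append(module_dir)
    return module_dir


# Placeholder location (an assumption); use the folder that holds load_data.py.
added = add_to_sys_path("some/folder/load_data")
add_to_sys_path("some/folder/load_data")  # the second call changes nothing
print(sys.path.count(added))  # 1
```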

1.3 How to use the functions in the data set loading program

Once the loading program has been imported, loading a data set is simple:

data_dir = r'E:\Pycharm\project\yeah&ok\dataset'
Batch_size = 32     # batch size
train_dataset, test_dataset = load_data_func(data_dir, batch_size=Batch_size)
test_image(train_dataset)   # display 9 images

You can then go on to build the network, train it, and carry out the remaining steps.

2. Import the libraries (at the top of the file)

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
import random
import tensorflow_datasets as tfds

3. Main function

The main function takes the data set path and the batch size as input and returns the training set and the test set.

def load_data_func(data_dir,batch_size):
    data_root = pathlib.Path(data_dir)  # read the path and create a Path object
    print(data_dir)
    print(data_root)
    all_image_path = list(data_root.glob('*/*'))    # '*/*' matches every file inside each subfolder
    print(all_image_path)
    all_image_path = [str(path) for path in all_image_path] # full path of every image, as strings
    print(all_image_path)
    random.shuffle(all_image_path)  # shuffle

    label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())    # names of the image folders
    label_to_index = dict((name, index) for index, name in enumerate(label_names))  # dict that maps each folder name to an integer
    print(label_to_index)   # OK:0, Yeah:1
    # get the label of every image
    all_image_label = [label_to_index[pathlib.Path(p).parent.name] for p in all_image_path]  # parent folder name of each image, converted to a number: 0, 1, 0, 1, ...
    print(len(all_image_label))  # number of samples found
    index_to_label = dict((v,k) for k,v in label_to_index.items())  # reverse mapping from number to label name, kept for later use

    image_path = all_image_path[5]
    image_show = (1 + load_preprocess_image(image_path)) / 2.  # rescale from [-1, 1] to [0, 1] (equivalent to image/255.) so it displays correctly
    plt.imshow(image_show)  # check that a picture can be displayed correctly
    plt.show()

    path_ds = tf.data.Dataset.from_tensor_slices(all_image_path)
    image_dataset = path_ds.map(load_preprocess_image)  # this is where the pictures are actually loaded; everything before was only paths

    label_dataset = tf.data.Dataset.from_tensor_slices(all_image_label)
    dataset = tf.data.Dataset.zip((image_dataset, label_dataset))  # build the data set; zip pairs each image with its label
    image_count = len(all_image_path)  # number of samples in the data set
    test_count = int(image_count * 0.2)
    train_count = image_count - test_count
    print(test_count, train_count)

    train_dataset = dataset.skip(test_count)  # skip the first test_count samples; the rest form the training set
    test_dataset = dataset.take(test_count)  # take the first test_count samples as the test set
    BATCH_SIZE = batch_size # buffer_size = train_count
    train_dataset = train_dataset.shuffle(buffer_size=150).repeat(3).batch(BATCH_SIZE)
    # if the data set is too small, add .repeat()
    test_dataset = test_dataset.batch(BATCH_SIZE)
    # data augmentation; the data was shuffled earlier, so only the training set needs augmenting
    train_dataset = train_dataset.map(augment)
    return train_dataset,test_dataset
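
The label mapping and the 80/20 split in the function above can be followed without TensorFlow. This is a simplified sketch in plain Python, not the original code: `labels_from_paths` and `split_counts` are hypothetical helper names, the example paths are made up, and the labels are derived from the paths themselves rather than by listing the subfolders. The batch count assumes the `.repeat(3).batch(...)` order used above, where the last batch may be partial.

```python
import math
import pathlib


def labels_from_paths(image_paths):
    """Map each path to an integer label taken from its parent folder name."""
    label_names = sorted({pathlib.PurePath(p).parent.name for p in image_paths})
    label_to_index = {name: index for index, name in enumerate(label_names)}
    labels = [label_to_index[pathlib.PurePath(p).parent.name] for p in image_paths]
    return label_to_index, labels


def split_counts(image_count, test_fraction=0.2, repeats=3, batch_size=32):
    """Return (test_count, train_count, train_batches) for the split shown above."""
    test_count = int(image_count * test_fraction)
    train_count = image_count - test_count
    # repeat() comes before batch(), so batches are cut from the repeated stream
    train_batches = math.ceil(train_count * repeats / batch_size)
    return test_count, train_count, train_batches


paths = ["dataset/Yeah/a.jpg", "dataset/OK/b.jpg", "dataset/Yeah/c.jpg"]
label_to_index, labels = labels_from_paths(paths)
print(label_to_index, labels)  # {'OK': 0, 'Yeah': 1} [1, 0, 1]
print(split_counts(100))       # (20, 80, 8)
```

With 100 pictures, 20 go to the test set and 80 to the training set; repeating the training set 3 times gives 240 samples, which is 8 batches of up to 32.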

4. Extract pictures from the path and normalize them

def load_preprocess_image(img_path):
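    img_raw = tf.io.read_file(img_path)           # read the file at this path
    img_tensor = tf.image.decode_jpeg(img_raw,channels=3)   # decode the picture; decode_image is generic but does not return a shape, so use the decoder that matches the format
    img_tensor = tf.image.resize(img_tensor,[160,160])      # resize the picture
    img_tensor = tf.cast(img_tensor, tf.float32)  # convert the data type
    img = img_tensor/127.5-1                   # normalize to [-1, 1]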
    return img
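
The normalization `img_tensor/127.5-1` maps pixel values from [0, 255] to [-1, 1], and `(1 + x) / 2` maps them back to [0, 1] for display. A quick sketch of that round trip on single numbers, using the hypothetical helper names `normalize` and `denormalize` (plain Python, so it runs without TensorFlow):

```python
def normalize(pixel):
    """Map a pixel value from [0, 255] to [-1, 1], as load_preprocess_image does."""
    return pixel / 127.5 - 1


def denormalize(value):
    """Map a normalized value from [-1, 1] back to [0, 1] for plt.imshow."""
    return (1 + value) / 2.0


for pixel in (0, 127.5, 255):
    print(pixel, normalize(pixel), denormalize(normalize(pixel)))
```

The round trip gives `pixel / 255`, which is why the display code adds 1 and divides by 2 before calling `plt.imshow`.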

5. Functions for data augmentation of pictures

Choose the augmentation steps according to your needs.

def augment(image,label):
    # random horizontal flip
    image = tf.image.random_flip_left_right(image)
    # random contrast
    image = tf.image.random_contrast(image,lower=0.0,upper=1.0)
    # random vertical flip
    image = tf.image.random_flip_up_down(image)
    # random brightness
    image = tf.image.random_brightness(image,max_delta=0.5)
    # random hue
    image = tf.image.random_hue(image,max_delta=0.3)
    # random saturation
    image = tf.image.random_saturation(image,lower=0.3,upper=0.5)
    return image,label
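
One way to pick augmentation steps without editing the function body each time is to build `augment` from a list of steps. The sketch below is an assumption, not the original design: `make_augment` is a hypothetical helper, and the two flips operate on nested lists purely so the example runs without TensorFlow.

```python
def make_augment(*steps):
    """Build an augment(image, label) function from the chosen image -> image steps."""
    def augment(image, label):
        for step in steps:
            image = step(image)
        return image, label
    return augment


def flip_left_right(image):
    """Stand-in for a horizontal flip on a list-of-rows image."""
    return [row[::-1] for row in image]


def flip_up_down(image):
    """Stand-in for a vertical flip on a list-of-rows image."""
    return image[::-1]


augment_flips = make_augment(flip_left_right, flip_up_down)
image, label = augment_flips([[1, 2], [3, 4]], 0)
print(image, label)  # [[4, 3], [2, 1]] 0
```

In the TensorFlow pipeline, the steps would instead be calls such as `tf.image.random_flip_left_right` or `lambda img: tf.image.random_brightness(img, max_delta=0.5)`, and the resulting function could be passed to `train_dataset.map`.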

6. Display 9 pictures to check the effect of data augmentation

This function is time-consuming, so there is no need to call it on every run.

def test_image(train_dataset):
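    # calling this once is enough
    plt.figure(figsize=(12,12))
    for batch in tfds.as_numpy(train_dataset):  # this step takes a long time; avoid it when possible
        for i in range(9):
            image, label = (1+batch[0][i])/2., batch[1][i]   # the image was normalized earlier, so restore it here to display it correctly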
            plt.subplot(3,3,i+1)
            plt.imshow(image)
            plt.grid(False)
        break
    plt.show()

Origin blog.csdn.net/weixin_45371989/article/details/106319320