Un método universal para cargar conjuntos de datos locales

1. Descripción

1.1 Descripción del formato de ubicación del conjunto de datos

Las diferentes categorías de imágenes en la carpeta del conjunto de datos deben ordenarse primero y colocarse en diferentes subcarpetas . El formato de ubicación es como se muestra en la figura:
Formato de almacenamiento del conjunto de datos
solo hay 2 categorías, por supuesto, varias categorías también están bien. Esto no requiere el número de categorías.

1.2 Descripción de la referencia de función

Para citar la función aquí en otros programas, el método es el siguiente:

import sys      #绝对路径引用,不然引用load_data会报错
#load_data所在程序路径
sys.path.append(r'E:\Pycharm\project\yeah&ok\load_data')	
from load_data import load_data_func,test_image,augment

Generalmente, solo se deben citar load_data_func y test_image.

1.3 Instrucciones sobre cómo utilizar funciones en el programa de conjunto de datos de carga

Después de cargar el programa del conjunto de datos, cargar el conjunto de datos es muy simple, el método de carga es el siguiente:

ata_dir = 'E:\Pycharm\project\yeah&ok\dataset'
Batch_size = 32     #批处理尺寸
train_dataset,test_dataset = load_data_func(data_dir,batchsize=Batch_size)
test_image(train_dataset)	#显示9张图像

Luego, puede continuar construyendo la estructura de la red, realizar entrenamiento y otros pasos.

2. Configure el archivo de la biblioteca (inicio)

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pathlib
import random
import tensorflow_datasets as tfds

3. Función principal

Función principal: La función es ingresar la ruta del conjunto de datos y el tamaño del lote, y devolver el conjunto de entrenamiento y el conjunto de prueba.

def load_data_func(data_dir,batch_size):
    data_root = pathlib.Path(data_dir)  #读取路径,创建path对象
    print(data_dir)
    print(data_root)
    all_image_path = list(data_root.glob('*/*'))    #*/*是获取文件夹下的所有文件及其子文件
    print(all_image_path)
    all_image_path = [str(path) for path in all_image_path] #获取所有图片的完整路径
    print(all_image_path)
    random.shuffle(all_image_path)  #打乱

    label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())    #获取图像文件夹名字
    label_to_index = dict((name, index) for index, name in enumerate(label_names))  #创建字典对象,设置图像名称的映射为整数
    print(label_to_index)   #OK:0,Yeah:1
	# 获取所有图像对应的标签
    all_image_label = [label_to_index[pathlib.Path(p).parent.name]for p in all_image_path]  #获取每个图象的父类名称,并变成数值,0101...
    print(len(all_image_label))	#显示获取的数据量
    index_to_label = dict((v,k) for k,v in label_to_index.items())  #获取数值对应的标签名字,以备后用

    image_patn = all_image_path[5]
    image_show = (1 + load_preprocess_image(image_patn)) / 2.  # 要变成image/255.才能正常显示
    plt.imshow(image_show)  # 这里是测试图片能不能正常显示
    plt.show()

    path_ds = tf.data.Dataset.from_tensor_slices(all_image_path)
    image_dataset = path_ds.map(load_preprocess_image)  # 这里才是把所有图片提取出来,前面的都是路径

    label_dataset = tf.data.Dataset.from_tensor_slices(all_image_label)
    dataset = tf.data.Dataset.zip((image_dataset, label_dataset))  # 做成数据集,zip将label和image对应起来
    image_count = len(all_image_path)  # 数据集的数量
    test_count = int(image_count * 0.2)
    train_count = image_count - test_count
    print(test_count, train_count)

    train_dataset = dataset.skip(test_count)  # 跳过test_count构成数据集
    test_dataset = dataset.take(test_count)  # 取test_count构成数据集
    BATCH_SIZE = batch_size # buffer_size = train_count
    train_dataset = train_dataset.shuffle(buffer_size=150).repeat(3).batch(BATCH_SIZE) 
    # 数据集数量不够则加个.repeat()
    test_dataset = test_dataset.batch(BATCH_SIZE)
    # 数据增强,OK,之前打乱过了,只需要对训练集数据增强
    train_dataset = train_dataset.map(augment)
    return train_dataset,test_dataset

4. Extraiga imágenes de la ruta y normalícelas

def load_preprocess_image(img_path):
    img_raw = tf.io.read_file(img_path)           #读取路径
    img_tensor = tf.image.decode_jpeg(img_raw,channels=3)   #解码图片 decode_image通用,但不会返回shape,改成对应的格式
    img_tensor = tf.image.resize(img_tensor,[160,160])      #改变图片大小
    img_tensor = tf.cast(img_tensor, tf.float32)  #转换数据类型
    img = img_tensor/127.5-1                   #标准化,归一化
    return img

5. Funciones para la mejora de datos de imágenes

Elija según sus necesidades.

def augment(image,label):
    #随机进行水平翻转
    image = tf.image.random_flip_left_right(image)
    #随机设置对比度
    image = tf.image.random_contrast(image,lower=0.0,upper=1.0)
    #垂直翻转
    image = tf.image.random_flip_up_down(image)
    #设置亮度
    image = tf.image.random_brightness(image,max_delta=0.5)
    #设置色度
    image = tf.image.random_hue(image,max_delta=0.3)
    #设置饱和度
    image = tf.image.random_saturation(image,lower=0.3,upper=0.5)
    return image,label

6. Muestra 9 imágenes, que se pueden usar para ver el efecto de imagen después de la mejora de datos

Esta función llevará mucho tiempo y no es necesario llamarla cada vez.

def test_image(train_dataset):
    #用一次就行了
    plt.figure(figsize=(12,12))
    for batch in tfds.as_numpy(train_dataset):  #这里耗时间很久。。尽量不用
        for i in range(9):
            image, label = (1+batch[0][i])/2., batch[1][i]   #image前面进行了归一化,因此这里要先恢复过来,才能正常显示图像
            plt.subplot(3,3,i+1)
            plt.imshow(image)
            plt.grid(False)
        break
    plt.show()

Supongo que te gusta

Origin blog.csdn.net/weixin_45371989/article/details/106319320
Recomendado
Clasificación