搭建简单图片分类的卷积神经网络(一)-- 训练模型的图片数据预处理

一、训练之前数据的预处理主要包括两个方面

1、将图片数据统一格式,以标签来命名并存到train文件夹中(假设原始图片按类别存到文件夹中)。

2、对命名好的图片进行训练集和测试集的划分以及图片数据化。

先对整个项目文件进行说明:

项目文件夹

image文件里是用来对模型测试的未处理图片(训练模型不用)。

logs文件存放训练好的模型。

prediction文件是image文件中图片经过模型测试后分类的图片。

train文件有两个文件子层orig_data和train_data,前一个是未处理训练模型图片,后一个是处理好的进行模型训练的图片。

.py文件是项目程序,其他项目自带的,无关。

二、OK!现在先进行第一步,新建IntputData.py文件

import os
from PIL import Image
#未处理图片位置
orig_picture = r'E:\PycharmPython\NewCnn\train\orig_data'

#已处理图片存储位置
gen_picturn = r'E:\PycharmPython\NewCnn\train\train_data'

#查询需要分类的类别以及总样本个数
classes = []
num_samples = 0

for str_classes in os.listdir(orig_picture):
    classes.append(str_classes)

#统一图片大小
def get_traindata(orig_dir,gen_dir,classes):
    i = 0
    for index,name in enumerate(classes):
        class_path = orig_dir + '\\' + name + '\\' #扫描原始图片
        gen_train_path = gen_dir +'\\' + name   #判断是否有文件夹
        folder = os.path.exists(gen_train_path)
        if not folder :
            os.makedirs(gen_train_path)
            print(gen_train_path,'new file')
        else:
            print('There is this flie')
        #给图片加编号保存
        for imagename_dir in os.listdir(class_path):
            i += 1
            origimage_path = class_path + imagename_dir
            #统一格式
            image_data = Image.open(origimage_path).convert('RGB')
            image_data = image_data.resize((64,64))
            image_data.save(gen_train_path + '\\'+str(index) + name + str(i) + '.jpg' )
            num_samples = i
    print('picturn :%d' % num_samples)

if __name__ == '__main__':
    get_traindata(orig_picture,gen_picturn,classes)

这段程序将原始图片统一为64X64格式,并分类保存。

三、然后新建GetCnnData文件

import os
import math
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf


#加载图片路径
train_data_dir = r'E:\PycharmPython\NewCnn\train\train_data'
classes = []
image_list = []
label_list = []



def get_files(file_path,ratdio):
    for str_classes in os.listdir(train_data_dir):
        classes.append(str_classes)
    for index, name in enumerate(classes):
        path = file_path + '\\'+name
        for file in os.listdir(path):
            image_list.append(path + '\\' + file)
            label_list.append(index)
        print(name,'ok')
    #打乱顺序
    temp = np.array([image_list,label_list])
    temp = temp.transpose()
    np.random.shuffle(temp)

    #打乱之后分出测试集和训练集
    image_data = list(temp[:,0])
    image_label = list(temp[:,1])

    n_sample = len(image_label)
    n_val = int(math.ceil(n_sample * ratdio))
    n_train = n_sample - n_val

    train_images = image_data[0:n_train]
    train_labels = image_label[0:n_train]
    train_labels = [int(float(i)) for i in train_labels]
    val_images = image_data[n_train:-1]
    val_labels = image_label[n_train:-1]
    val_labels = [int(float(i)) for i in val_labels]

    return train_images,train_labels,val_images,val_labels

def get_batch(image,label,image_W,image_H,batch_size,capacity):
    #统一数据类型
    image = tf.cast(image,tf.string)
    label = tf.cast(label,tf.int32)
    # make an input queue
    input_queue = tf.train.slice_input_producer([image, label])
    label = input_queue[1]
    image_contents = tf.read_file(input_queue[0])  # read img from a queue
    # 将图像解码,不同类型的图像不能混在一起,要么只用jpeg,要么只用png等。
    image = tf.image.decode_jpeg(image_contents, channels=3)

    # 数据预处理,对图像进行旋转、缩放、裁剪、归一化等操作,让计算出的模型更健壮。
    image = tf.image.resize_image_with_crop_or_pad(image, image_W, image_H)
    image = tf.image.per_image_standardization(image)

    # image_batch: 4D tensor [batch_size, width, height, 3],dtype=tf.float32
    # label_batch: 1D tensor [batch_size], dtype=tf.int32
    image_batch, label_batch = tf.train.batch([image, label],
                                              batch_size=batch_size,
                                              num_threads=32,
                                              capacity=capacity)
    # 重新排列label,行数为[batch_size]
    label_batch = tf.reshape(label_batch, [batch_size])
    image_batch = tf.cast(image_batch, tf.float32)
    return image_batch, label_batch


对输入到CNN的图片数据的处理。

连载:https://blog.csdn.net/qq_28821995/article/details/83587530

  https://blog.csdn.net/qq_28821995/article/details/83587802

参考:https://blog.csdn.net/ywx1832990/article/details/78610231

猜你喜欢

转载自blog.csdn.net/qq_28821995/article/details/83587032