05 CNN 猴子类别检测

一、数据集下载

Kaggle 数据集:[10 Monkey Species]

二、数据集准备

2.1 指定路径

from tensorflow import keras
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


# Locations of the 10-monkey dataset: one directory per split plus the
# text file that describes the class labels.
train_dir = '/newdisk/darren_pty/CNN/ten_monkey/training/'
valid_dir = '/newdisk/darren_pty/CNN/ten_monkey/validation/'
label_file = '/newdisk/darren_pty/CNN/ten_monkey/monkey_labels.txt'

# Load the label description table (first row is the header) and show it.
labels = pd.read_csv(
    label_file,
    header=0,
)
print(labels)

2.2 数据增强

# Image data generator for the training split, with data augmentation.
# The options are collected in one mapping so each can be documented on its
# own line and the constructor call stays short.
augmentation_kwargs = dict(
    rescale=1. / 255,        # map JPEG pixel values from 0-255 to 0-1
    rotation_range=40,       # random rotation range, in degrees
    width_shift_range=0.2,   # random horizontal shift (fraction of width)
    height_shift_range=0.2,  # random vertical shift (fraction of height)
    shear_range=0.2,         # random shear intensity
    zoom_range=0.2,          # random zoom range
    horizontal_flip=True,    # random left-right flip
    vertical_flip=True,      # random up-down flip
    fill_mode='nearest',     # how pixels exposed by a transform are filled
)
train_datagen = keras.preprocessing.image.ImageDataGenerator(**augmentation_kwargs)

三、从数据集中生成数据

# Image geometry and batching parameters used when reading the dataset.
height = 128
width = 128
channels = 3
batch_size = 32
num_classes = 10

# Options shared by the training and validation directory iterators:
# resize every image, batch them, shuffle with a fixed seed, and emit
# one-hot ('categorical') labels.
flow_kwargs = dict(
    target_size=(height, width),
    batch_size=batch_size,
    shuffle=True,
    seed=7,
    class_mode='categorical',
)

# Validation images are only rescaled to 0-1; no augmentation is applied.
valid_datagen = keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)

train_generator = train_datagen.flow_from_directory(train_dir, **flow_kwargs)
valid_generator = valid_datagen.flow_from_directory(valid_dir, **flow_kwargs)

# Report the number of images found in each split, one count per line.
print(train_generator.samples, valid_generator.samples, sep='\n')

Found 1098 images belonging to 10 classes.
Found 272 images belonging to 10 classes.
1098
272

四、模型

# Total number of images in each split.
train_num = train_generator.samples
valid_num = valid_generator.samples

# Pull a single batch to inspect the tensor shapes and the one-hot labels.
# The built-in next() goes through the iterator protocol, so it does not
# depend on the iterator exposing a legacy `.next()` method.
x, y = next(train_generator)
print(x.shape, y.shape)
print(y)


# VGG-style CNN: three stages of (Conv2D, Conv2D, MaxPool) with 32, 64 and
# 128 filters, followed by a dense classifier head.
model = keras.models.Sequential()

# Stage 1: two 3x3 convolutions with 32 filters each.
# The input shape is (height, width, channels) and is taken from the
# constants used by the data generators, so the two cannot drift apart.
model.add(keras.layers.Conv2D(filters=32,
                              kernel_size=3,
                              padding='same',
                              activation='relu',
                              input_shape=(height, width, channels)))
model.add(keras.layers.Conv2D(filters=32,
                              kernel_size=3,
                              padding='same',
                              activation='relu'))
# Max pooling with the layer's default settings.
model.add(keras.layers.MaxPool2D())

# Stage 2: two 3x3 convolutions with 64 filters each.
model.add(keras.layers.Conv2D(filters=64,
                              kernel_size=3,
                              padding='same',
                              activation='relu'))
model.add(keras.layers.Conv2D(filters=64,
                              kernel_size=3,
                              padding='same',
                              activation='relu'))
model.add(keras.layers.MaxPool2D())

# Stage 3: two 3x3 convolutions with 128 filters each.
model.add(keras.layers.Conv2D(filters=128,
                              kernel_size=3,
                              padding='same',
                              activation='relu'))
model.add(keras.layers.Conv2D(filters=128,
                              kernel_size=3,
                              padding='same',
                              activation='relu'))
# Max pooling; the output size is rounded down when it does not divide evenly.
model.add(keras.layers.MaxPooling2D())

# Classifier head: flatten the feature maps, two hidden dense layers, then a
# softmax over the classes (one output unit per class).
model.add(keras.layers.Flatten())
model.add(keras.layers.Dense(512, activation='relu'))
model.add(keras.layers.Dense(256, activation='relu'))
model.add(keras.layers.Dense(num_classes, activation='softmax'))

# Categorical cross-entropy matches the one-hot labels produced by the
# generators (class_mode='categorical').
model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

# summary() prints the table itself and returns None, so it is not wrapped
# in print() (which would emit a stray "None" line).
model.summary()

五、训练

# Train for 10 epochs, validating after each one.
# len(generator) is the number of batches needed to cover the whole split,
# including the final partial batch. Floor division (samples // batch_size)
# would drop that last batch, so part of the validation set would be left
# out of every evaluation.
history = model.fit(train_generator,
                    steps_per_epoch=len(train_generator),
                    epochs=10,
                    validation_data=valid_generator,
                    validation_steps=len(valid_generator))

扫描二维码关注公众号,回复: 16515031 查看本文章

猜你喜欢

转载自blog.csdn.net/peng_258/article/details/132735525
CNN