使用 VGG16 进行猫狗图片分类

版权声明:本文为pureszgd原创文章,未经允许不得转载, 要转载请评论留言! https://blog.csdn.net/pureszgd/article/details/83302528
from keras.models import Sequential
from keras.layers import Conv2D, MaxPool2D, Activation, Dropout, Flatten, Dense
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator, img_to_array, load_img
from keras.applications.vgg16 import VGG16
from keras.models import load_model
import numpy as np

# Convolutional base: VGG16 with ImageNet weights, without its original
# fully-connected classifier, taking 150x150 RGB inputs.
vgg16_model = VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3))

# Classifier head stacked on the convolutional base: flatten the feature maps,
# one hidden ReLU layer with dropout, then a 2-way softmax (one unit per class).
top_model = Sequential([
    Flatten(input_shape=vgg16_model.output_shape[1:]),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(2, activation='softmax'),
])

# Full network = convolutional base followed by the classifier head.
# NOTE(review): no layer is marked trainable=False, so training this model
# also updates the pretrained VGG16 weights (full fine-tuning).
model = Sequential([vgg16_model, top_model])

# Training images: pixel values scaled to [0, 1] plus random geometric
# augmentation (rotation, shifts, shear, zoom, horizontal flip).
train_datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    # Float literal on purpose: under Python 2, 1/255 is integer division and
    # evaluates to 0, so the images would not be scaled like the /255 applied
    # at prediction time.
    rescale=1. / 255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# Evaluation images: rescaled only, no augmentation.
test_datagen = ImageDataGenerator(
    rescale=1. / 255
)

batch_size = 32

# Training data: read from the 'train' directory (one subdirectory per class),
# resized to the network input size.
train_generator = train_datagen.flow_from_directory(
    'train',
    target_size=(150, 150),
    batch_size=batch_size
)

# Test/validation data: read from the 'test' directory, same layout and size.
test_generator = test_datagen.flow_from_directory(
    'test',
    target_size=(150, 150),
    batch_size=batch_size
)

# Mapping from class-directory name to label index, e.g. {'cat': 0, 'dog': 1}.
print(train_generator.class_indices)

# SGD with a small learning rate and momentum. Categorical cross-entropy
# matches the 2-unit softmax head.
model.compile(optimizer=SGD(lr=1e-4, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy'])

# One epoch = one full pass over each dataset. len(generator) is the number of
# batches per pass (an integer), so the step counts follow the dataset size
# instead of a hard-coded constant.
model.fit_generator(
    train_generator,
    epochs=20,
    validation_data=test_generator,
    steps_per_epoch=len(train_generator),
    validation_steps=len(test_generator),
)
model.save('model_vgg16.h5')

# Index -> class name, ordered by the generator's class indices so it stays
# in sync with the labels used for training (['cat', 'dog'] here).
label = np.array(sorted(train_generator.class_indices,
                        key=train_generator.class_indices.get))
model = load_model('model_vgg16.h5')

# Preprocess one image the same way the generators do: load it resized to the
# network input size, then scale pixels to [0, 1].
image = load_img('test/cat/1.jpg', target_size=(150, 150))
image = img_to_array(image)
image = image / 255.
image = np.expand_dims(image, 0)  # add a batch axis -> (1, 150, 150, 3)
print(image.shape)

# predict() returns the softmax probabilities per class; argmax over the last
# axis selects the predicted class index (predict_classes is not available on
# models in newer Keras releases).
print(label[np.argmax(model.predict(image), axis=-1)])
运行输出：
Found 400 images belonging to 2 classes.
Found 200 images belonging to 2 classes.
{'dog': 1, 'cat': 0}
(1, 150, 150, 3)
['dog']

猜你喜欢

转载自blog.csdn.net/pureszgd/article/details/83302528