Keras训练自定义数据集

一、读取文件夹数据

# Load every .png image found under each class sub-folder of `path`, resized
# to `image_size`, and record the sub-folder name as that image's label.
train_idx = 0  # running count of images loaded
npy_idx = 0  # not used in this section
path = './/dataset_path//'
files = os.listdir(path)  # entries of the dataset directory (class folders)
random.shuffle(files)
images = []
labels = []
for f in files:  # every folder under the dataset directory
    file_path = os.path.join(path, str(f)) + '//'
    # os.walk also descends into nested folders, so each image path is built
    # from `root` (the folder that actually holds the file) rather than from
    # the class folder. `img_files` is a separate name so the outer `files`
    # list is not overwritten.
    for root, _dirs, img_files in os.walk(file_path):
        for file in img_files:
            # Compare case-insensitively so '.PNG' files are not skipped.
            if os.path.splitext(file)[1].lower() == '.png':
                train_idx += 1
                img_path = os.path.join(root, file)
                img = image.load_img(img_path, target_size=image_size)
                images.append(image.img_to_array(img))
                # NOTE(review): labels are folder-name strings; Keras losses
                # need numeric targets, so they presumably must be encoded
                # (integer indices or one-hot) before model.fit — confirm.
                labels.append(f)
# Fail early with a clear message instead of inside train_test_split.
if not images:
    raise ValueError('no .png images found under {}'.format(path))
images = np.array(images)  # (num, h, w, 3)
labels = np.array(labels)  # (num,)
images /= 255  # scale pixel values to [0, 1] (assumes 8-bit images)
# Hold out 20% of the samples. Returns, in order: training data, validation
# data, training labels, validation labels.
x_train, x_test, y_train, y_test = train_test_split(
    images, labels, test_size=0.2)

二、模型构建与编译

""" 共4层卷积网、二层全连接层"""

# CNN: four convolution stages (Conv2D -> MaxPool2D -> Dropout) followed by
# two fully connected layers.
model = Sequential()

# The input shape is taken from `image_size`, the same (height, width) that
# image.load_img resizes every image to, so the network matches the loaded
# arrays even when the images are not square.
model.add(Conv2D(32, kernel_size=(5, 5), input_shape=tuple(image_size) + (3,),
                 activation='relu', padding='same'))
model.add(MaxPool2D())
model.add(Dropout(0.3))

model.add(Conv2D(64, kernel_size=(5, 5), activation='relu', padding='same'))
model.add(MaxPool2D())
model.add(Dropout(0.3))

model.add(Conv2D(128, kernel_size=(5, 5), activation='relu', padding='same'))
model.add(MaxPool2D())
model.add(Dropout(0.5))

model.add(Conv2D(256, kernel_size=(5, 5), activation='relu', padding='same'))
model.add(MaxPool2D())
model.add(Dropout(0.5))

model.add(Flatten())
model.add(Dense(512, activation='relu'))
model.add(Dropout(0.5))

model.add(Dense(n_class, activation='softmax'))  # one softmax output per class

model.summary()
# `learning_rate` is the current name of Adam's step-size argument; the old
# `lr` keyword is deprecated and rejected by newer Keras releases.
model.compile(loss=loss_func, optimizer=Adam(learning_rate=0.0003),
              metrics=['accuracy'])

三、数据喂入

# Train on the training split; after every epoch Keras also evaluates the
# model on the held-out (x_test, y_test) split and reports those metrics.
model.fit(
    x_train,
    y_train,
    batch_size=nbatch_size,
    epochs=nepochs,
    validation_data=(x_test, y_test),
    verbose=1,
)

四、模型保存

# Save the model architecture and its weights as separate files in ./models.
os.makedirs('./models', exist_ok=True)  # open() below fails if it is missing
try:
    arch_string, arch_path = model.to_yaml(), './models/model_name.yaml'
except (RuntimeError, AttributeError):
    # model.to_yaml() was removed from newer TensorFlow/Keras (it raises
    # RuntimeError there, or no longer exists at all); fall back to JSON.
    arch_string, arch_path = model.to_json(), './models/model_name.json'
with open(arch_path, 'w', encoding='utf-8') as outfile:
    outfile.write(arch_string)
# NOTE(review): Keras 3 requires weight filenames ending in '.weights.h5';
# confirm against the Keras version in use.
model.save_weights('./models/model_name.h5')

猜你喜欢

转载自blog.csdn.net/Harrison509/article/details/88855310