初学Keras(构建模型,训练数据)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/xiaomifanhxx/article/details/81742756

Keras是一款比较容易上手的深度学习框架,用它来构建模型、准备训练数据和训练模型都很方便。

1 训练数据的准备(图像数据生成器)

def prepare_input_data(img_width, img_height, cfg=None):
    """Create the training and validation image generators.

    Training images are scaled by 1/255 and augmented (shear, zoom,
    horizontal flip). Validation images are only scaled and are read with
    ``shuffle=False``, so their order is stable between evaluations.

    Args:
        img_width, img_height: passed to Keras as
            ``target_size=(img_width, img_height)``.
        cfg: mapping providing 'Train_path', 'Val_path' and 'Batch_size'.
            When omitted, the module-level ``config`` is used, which keeps the
            original calling convention working.

    Returns:
        ``(train_generator, validation_generator)`` as returned by
        ``flow_from_directory`` with ``class_mode='categorical'``.
    """
    # The global ``config`` is only looked up when no cfg is supplied.
    settings = config if cfg is None else cfg
    batch_size = int(settings['Batch_size'])
    target_size = (img_width, img_height)

    train_datagen = image.ImageDataGenerator(
        rescale=1. / 255,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True)
    # No augmentation for validation data, only the same 1/255 scaling.
    val_datagen = image.ImageDataGenerator(rescale=1. / 255)

    train_generator = train_datagen.flow_from_directory(
        settings['Train_path'],
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical')
    validation_generator = val_datagen.flow_from_directory(
        settings['Val_path'],
        target_size=target_size,
        batch_size=batch_size,
        class_mode='categorical',
        shuffle=False)
    # Show the class-name -> label-index mapping inferred from the folders.
    print(train_generator.class_indices)

    return train_generator, validation_generator

2 模型构建

import numpy as np

import keras
from keras.layers import Activation, Conv2D, Dense, Dropout, Flatten, MaxPooling2D
from keras.layers.normalization import BatchNormalization
from keras.models import Sequential, load_model
from keras.optimizers import SGD, Adadelta, Adagrad, Adam, Adamax, Nadam, RMSprop
from keras.preprocessing import image

def get_model(config, num_classes=8):
    """Build the CNN classifier as an (uncompiled) Keras ``Sequential`` model.

    Architecture: four Conv2D -> BatchNormalization -> MaxPooling2D blocks
    with 32, 64, 128 and 128 filters, followed by
    Flatten -> Dense(1024, relu) -> Dropout(0.5) -> Dense(num_classes, softmax).

    Args:
        config: mapping with 'IMG_WIDTH' and 'IMG_HEIGHT' entries (strings or
            ints; converted with ``int``).
        num_classes: number of output classes; defaults to the original 8.

    Returns:
        The ``Sequential`` model; compiling it is left to the caller.
    """
    # NOTE(review): width is passed first and 3 input channels are assumed;
    # this matches the (img_width, img_height) order used for target_size in
    # prepare_input_data, so it only matters for naming when width != height.
    input_shape = (int(config['IMG_WIDTH']), int(config['IMG_HEIGHT']), 3)

    model = Sequential()
    # Block 1: 5x5 'valid' convolution; the only layer declaring input_shape.
    model.add(Conv2D(filters=32, kernel_size=(5, 5), padding='valid',
                     activation='relu', input_shape=input_shape))
    model.add(BatchNormalization())
    model.add(MaxPooling2D(pool_size=(2, 2), padding='valid'))
    # Blocks 2-4: 3x3 'same' convolutions with the same BN + pooling tail.
    for filters in (64, 128, 128):
        model.add(Conv2D(filters=filters, kernel_size=(3, 3), padding='same',
                         activation='relu'))
        model.add(BatchNormalization())
        model.add(MaxPooling2D(pool_size=(2, 2), padding='valid'))
    # Block 5: fully connected classifier head.
    model.add(Flatten())
    model.add(Dense(units=1024, activation='relu'))
    model.add(Dropout(rate=0.5))
    model.add(Dense(units=num_classes, activation='softmax'))
    return model

3 编译并训练模型

def train(config, steps_per_epoch=2, epochs=1, validation_steps=1):
    """Build, compile and fit the model described by *config*.

    Args:
        config: mapping providing 'Loss', 'learning_rate', 'METRICS',
            'IMG_WIDTH' and 'IMG_HEIGHT' (string values are converted).
        steps_per_epoch, epochs, validation_steps: forwarded to
            ``fit_generator``. The defaults (2 / 1 / 1) are the original
            hard-coded smoke-test values; raise them for a real training run.

    Returns:
        ``(model, hist)``: the fitted model and the ``LossHistory`` callback,
        so the caller can save the model or call ``hist.loss_plot(...)``.
    """
    model = get_model(config)
    model.compile(loss=config['Loss'],
                  optimizer=Adam(lr=float(config['learning_rate'])),
                  metrics=[config['METRICS']])

    # prepare_input_data is only given the image size here; in the original
    # snippet it reads the data paths and batch size from a module-level
    # ``config``.
    train_generator, validation_generator = prepare_input_data(
        int(config['IMG_WIDTH']), int(config['IMG_HEIGHT']))
    hist = LossHistory()

    model.fit_generator(
        train_generator,
        steps_per_epoch=steps_per_epoch,
        epochs=epochs,
        validation_data=validation_generator,
        validation_steps=validation_steps,
        verbose=1,
        callbacks=[hist],
    )
    print(hist.losses)
    return model, hist

4 配置文件的解析

def get_cfg(path=r'C:/Users/zjunzhao/Desktop/cfg.txt'):
    """Parse a text config file with one ``key:value`` pair per line.

    Only the first ':' separates key from value, so values such as Windows
    paths (``C:/data``) stay intact. Whitespace around keys and values is
    stripped. Lines without a ':' (e.g. blank lines) are skipped instead of
    raising IndexError.

    Args:
        path: config file to read; defaults to the original hard-coded path.

    Returns:
        dict mapping each key to its string value; a later duplicate key
        overwrites an earlier one.
    """
    cfg = {}
    # ``with`` guarantees the file handle is closed (the original leaked it).
    with open(path) as f:
        for line in f:
            key, sep, value = line.strip().partition(':')
            if not sep:
                continue
            cfg[key.strip()] = value.strip()
    return cfg

5 记录并绘制Loss/Acc曲线

class LossHistory(keras.callbacks.Callback):
    """Keras callback that records loss/accuracy histories and plots them.

    Each of ``losses``, ``accuracy``, ``val_loss`` and ``val_acc`` is a dict
    with a "batch" list (appended in ``on_batch_end``) and an "epoch" list
    (appended in ``on_epoch_end``).
    """

    def on_train_begin(self, logs=None):
        # Start fresh histories at the beginning of every training run.
        self.losses = {"batch": [], "epoch": []}
        self.accuracy = {"batch": [], "epoch": []}
        self.val_loss = {"batch": [], "epoch": []}
        self.val_acc = {"batch": [], "epoch": []}

    @staticmethod
    def _first_metric(logs, *names):
        """Return the value of the first key in *names* present in *logs*, else None.

        Older Keras logs accuracy as 'acc'/'val_acc', newer releases as
        'accuracy'/'val_accuracy'; trying both keeps the history populated.
        """
        for name in names:
            if name in logs:
                return logs[name]
        return None

    def _record(self, period, logs):
        """Append the current metrics to the *period* ("batch" or "epoch") lists."""
        logs = logs or {}
        self.losses[period].append(logs.get('loss'))
        self.accuracy[period].append(self._first_metric(logs, 'acc', 'accuracy'))
        self.val_loss[period].append(logs.get('val_loss'))
        self.val_acc[period].append(self._first_metric(logs, 'val_acc', 'val_accuracy'))

    def on_batch_end(self, batch, logs=None):
        # Validation metrics are typically reported only at epoch end, so the
        # val_* batch entries are likely None.
        self._record("batch", logs)

    def on_epoch_end(self, epoch, logs=None):
        self._record("epoch", logs)

    def loss_plot(self, loss_type):
        """Plot the recorded accuracy and loss curves for *loss_type*.

        Args:
            loss_type: "batch" or "epoch"; validation curves are drawn only
                for "epoch".

        NOTE(review): relies on a module-level ``plt`` (presumably
        ``matplotlib.pyplot``) that this snippet never imports.
        """
        iters = range(len(self.losses[loss_type]))
        plt.figure()
        plt.plot(iters, self.accuracy[loss_type], 'r', label='train acc')
        plt.plot(iters, self.losses[loss_type], 'g', label='train loss')
        if loss_type == 'epoch':
            plt.plot(iters, self.val_acc[loss_type], 'b', label='val acc')
            plt.plot(iters, self.val_loss[loss_type], 'k', label='val loss')
        plt.grid(True)
        plt.xlabel(loss_type)
        plt.ylabel('acc-loss')
        plt.legend(loc="upper right")
        plt.show()

6 python实现拷贝指定文件到指定目录

import os
import shutil

# Copy every regular file in src_dir whose base name (extension excluded)
# contains "python" (case-insensitive) to dst_prefix + original file name.
src_dir = u"D:\\notes\\python\\资料\\"
# NOTE(review): the original prefixes every copied file name with the literal
# "newname"; kept as-is -- drop it if a plain copy into d:\copy was intended.
dst_prefix = u"d:\\copy\\newname"
for entry in os.listdir(src_dir):
    src_path = src_dir + entry
    # shutil.copyfile only copies regular files; skip sub-directories.
    if not os.path.isfile(src_path):
        continue
    # splitext handles names with no dot or several dots, where the original
    # two-way unpacking of name.split(".") raised ValueError.
    stem, _ext = os.path.splitext(entry)
    if 'python' in stem.lower():
        shutil.copyfile(src_path, dst_prefix + entry)

参考博客:

绘制loss:https://www.cnblogs.com/jzy996492849/p/7234233.html

https://blog.csdn.net/u011037837/article/details/51593099

https://blog.csdn.net/u013381011/article/details/78911848

保存checkpoints以及keypoints detection:https://blog.csdn.net/hjimce/article/details/49095199

猜你喜欢

转载自blog.csdn.net/xiaomifanhxx/article/details/81742756