Ajuste de combate de rede pré-treinamento

Usar a rede de pré-treinamento para o ajuste fino é realmente usar o estado atual do parâmetro da rede de pré-treinamento como o estado de início do treinamento e continuar o treinamento com seus próprios dados. Como a rede de pré-treinamento foi treinada por muitas outras rodadas, não há necessidade de ajustá-lo. Continue treinando por muitas rodadas.

# - * - codificação: utf-8 - * - 
"" " 
Criado em quinta-feira, 29 de novembro às 14:26:37 2018 

@author: 13260 
" "" 
# - * - codificação: utf-8 - * - 
importa os 
de keras. utils importam plot_model 
de keras.applications.resnet50 importam ResNet50 
de keras.applications.vgg19 importam VGG19 
de keras.applications.inception_v3 import InceptionV3 
de keras.layers importam Dense, Flatten, GlobalAveragePooling2D 
de keras.models import Model, load_model 
de keras.optimizers SGD 
de keras.preprocessing.image import ImageDataGenerator 
import matplotlib.pyplot como 
 
classe plt PowerTransferMode:  
    # 数据 准备
    def DataGen (self, dir_path, img_row, img_col, batch_size, is_train):
        if is_train: 
            datagen = ImageDataGenerator (rescale = 1. / 255, 
                zoom_range = 0.25, rotation_range = 15., 
                channel_shift_range = 25., width_shift_range = 0.02, height_shift_range = 0.02, 
                horizontal_flip = True, fill_mode = 'constant') 
        else: 
            datagen = ImageDataGenerator (rescale = 1. / 255) 
 
        generator = datagen.flow_from_directory ( 
            dir_path, target_size = (img_row, img_col), 
            batch_size = batch_size, 
            # class_mode = 'binary', 
            shuffle = is_train) 
 
        gerador de retorno 
 
    #ResNet 模型 
    ResNet50_model (self, lr = 0,005, decaimento = 1e-6, momento = 0,9, nb_classes = 45, img_rows = 197, img_cols = 197, RGB = True, is_plot_model = False): 
        color = 3 se RGB mais 1 
        base_model = ResNet50 (pesos = 'imagenet', include_top = False, pooling = None, input_shape = (img_rows, img_cols, color), 
                              classes = nb_classes) 
 
        #Congelar todas as camadas do base_model, para que você possa obter corretamente o recurso de gargalo 
        da camada no base_model. layers: 
            layer.trainable = False 
 
        x = base_model.output # Adicione 
        sua própria camada de classificação de link completo 
        x = Achatar () (x) 
        #x = GlobalAveragePooling2D () (x) 
        #x = Denso (1024, ativação = 'relu') x)
        previsões = modelo Denso (nb_classes, ativação = 'softmax') (x) 
        color = 3 se RGB mais 1
  
        # Modelo de treinamento
        model = Model (entradas = base_model.input, saídas = previsões) 
        sgd = SGD (lr = lr, decaimento = decaimento, momento = momento, nesterov = True) 
        model.compile (perda = 'categorical_crossentropy', otimizador = sgd, métricas = ['precisão']) 
 
        # 模型 is 
        se is_plot_model: 
            plot_model (model, to_file = 'resnet50_model.png', show_shapes = True) 
 
        modelo de retorno 
 
 
    #VGG 模型
    def VGG19_model (self, lr = 0,005, decaimento = 1e-6, momento = 0.9, nb_classes = 45, img_rows = 197, img_cols = 197, RGB = True, is_plot_model = False): 
        base_model = VGG19 (pesos = 'imagenet', include_top = False, conjunto = Nenhum, input_shape = (img_rows, img_cols, color) , 
                              classes = nb_classes)
 
        # Congele todas as camadas do base_model, para que você possa obter os recursos de gargalo corretamente 
        na camada em base_model.layers: 
            layer.trainable = False 
 
        x = base_model.output #Adicione 
        sua própria camada de classificação de link completo 
        x = GlobalAveragePooling2D () (x) 
        x = Denso ( 1024, ativação = 'relu') (x) 
        previsões = Denso (nb_classes, ativação = 'softmax') (x) 
        # 
 
        Modelo de modelo de treinamento = Modelo (entradas (= base_model.input, saídas = previsões)  
        sgd = SGD (lr = lr, decaimento = decaimento, momento = momento, nesterov = True)
        model.compile (loss = ' categorical_crossentropy ', otimizador = sgd, métricas = [' precisão ']) 
 
        # plot 
        se is_plot_model: 
            plot_model (modelo, to_file =' vgg19_model.png ', show_shapes = True) 
 
        modelo de retorno
 
    # InceptionV3 模型
    def InceptionV3_model (self, lr = 0.005, decaimento = 1e-6, momento = 0.9, nb_classes = 2, img_rows = 197, img_cols = 197, RGB = True, 
                    is_plot_model = False): 
        color = 3 se RGB mais 1 
        base_model = InceptionV3 (pesos = 'imagenet', include_top = False, pooling = None, 
                           input_shape = (img_rows, img_cols, color), 
                           classes = nb_classes) 
  
        # Congele todas as camadas do base_model para que você possa obter os recursos de gargalo corretamente
        para a camada em base_model.layers: 
            layer.trainable = False 
 
        x = base_model.output 
        #层 自己 的 全 链接
        层 x = GlobalAveragePooling2D () (x) 
        x = Denso (1024, ativação = 'relu') (x)
        predições = Denso (nb_classes, ativação = 'softmax') (x) 
 
        # 训练 模型
        model = Modelo (entradas = base_model.input, saídas = previsões) 
        sgd = SGD (lr = lr, decaimento = decaimento, momento = momento, nesterov = True) 
        model.compile (loss = 'categorical_crossentropy', otimizador = sgd, métricas = ['precisão']) 
 
        # 
        绘图 se is_plot_model: 
            plot_model (model, to_file = 'inception_v3_model.png', show_shapes = True) 
 
        retornar modelo 
  
    #训练模型
    def train_model ( self, model, epochs, train_generator, steps_per_epoch, validation_generator, validation_steps, model_url, is_load_model = False): 
        # 模型 模型
        se is_load_model e os.path.exists (model_url):
            model = load_model (model_url) 
 
        history_ft = model.fit_generator ( 
            train_generator, 
            steps_per_epoch = steps_per_epoch, 
            épocas = épocas, 
            validation_data = validation_generator, 
            validation_steps = validation_steps) 
        # Anúncio Salvo模型
        model.save (model_url, overwrite = True) 
        retorno history_ft 
 
    #画图
    def plot_training ( self, history): 
      acc = history.history ['acc'] 
      épocas = intervalo (len (acc)) 
      plt.plot (épocas, acc, 'b-') 
      val_acc = history.history ['val_acc']
      loss = history.history ['perda'] 
      val_loss = history.history ['val_loss'] 
      plt.plot (épocas, val_acc, 'r') 
      plt.title ('Precisão de treinamento e validação') 
      plt.figure () 
      plt. plot (épocas, perda, 'b-') 
      plt.plot (épocas, val_loss, 'r-') 
      plt.title ('Perda de treinamento e validação') 
      plt.show () 
 
 
if __name__ == '__main__': 
    image_size = 197 
    batch_size = 32 
 
    transfer = PowerTransferMode () 
 
    # 数据
    train_generator = transfer.DataGen ('F: / shiyan / TensorFlow / retrain / data / trans_train', tamanho da imagem, tamanho da imagem, tamanho do lote, Verdadeiro) 
    validation_generator = transfer.DataGen (' F : / shiyan / TensorFlow / retrain / data / test ', tamanho da imagem, tamanho da imagem, tamanho do lote, False) 
 
    # VGG19 
    model = transfer.VGG19_model (nb_classes = 45, img_rows = image_size, img_cols = image_size, is_plot_model = False)
    history_ft = transfer.train_model (modelo, 10, train_generator, 600, validation_generator, 60, 'vgg19_model_weights.h5', is_load_model = False) 
 
    # ResNet50 
    #model = transfer.ResNet50_model (nb_classes = 45, img_rows = image_size, img_ot = False) 
    #history_ft = transfer.train_model (modelo, 10, train_generator, 600, validation_generator, 60, 'resnet50_model_weights.h5', is_load_model = False) 
 
    # InceptionV3 
    #model = transfer.InceptionV3_model (nb_classes = 2, img_rows = image_size, img_cols = image_size, is_plot_model = True)
    #history_ft = transfer.train_model (modelo, 10, train_generator, 600, validation_generator, 60, 'inception_v3_model_weights.h5', is_load_model = False) 
 
    # acc 的 acc_loss 图
    # transfer.plot_training (history_ft)
Publicado 50 artigos originais · Curtidas5 · Visitas: mais de 20.000

Acho que você gosta

Origin blog.csdn.net/qq_31207499/article/details/88635775
Recomendado
Clasificación