Keras é uma biblioteca de aprendizado profundo baseada em Python, que encapsula uma variedade de modelos pré-treinados com base no treinamento do ImageNet, o que é muito adequado para o aprendizado de transferência.
Existem muitos tutoriais sobre o método de instalação do keras.O autor instala diretamente o keras, o que é conveniente e rápido.
Em geral, existem três maneiras mais fáceis de realizar o aprendizado por transferência: uma é usar diretamente o modelo pré-treinado; a segunda é ajustar os parâmetros com base no modelo pré-treinado; a terceira é treinar a rede neural. A camada superior (ou seja, a camada de classificação) é removida e a parte restante é considerada um extrator de recursos para extrair recursos de dados. O código para extrair recursos de imagem usando rede neural pré-treinada é o seguinte (tome VGG19 como exemplo); se você precisar chamar outros modelos, precisará alterar apenas o nome do modelo correspondente e, finalmente, salvar os recursos de camada totalmente conectados de cada imagem como um txt O arquivo é armazenado no caminho especificado para operações subseqüentes, o que é muito conveniente.
# -*- coding: utf-8 -*-
"""
Created on Mon Nov 12 21:35:24 2018
@author: 13260
"""
import os
from keras.applications.vgg19 import VGG19
from keras.preprocessing import image
from keras.applications.vgg19 import preprocess_input
from keras.models import Model
import numpy as np
def feature_extraction(filename,save_path):
model = Model(inputs=base_model.input, outputs=base_model.get_layer('fc2').output)
img_path = filename
img = image.load_img(img_path, target_size=(224, 224))
x = image.img_to_array(img)
x = np.expand_dims(x, axis=0)
x = preprocess_input(x)
fc2 = model.predict(x) // 获取VGG19全连接层特征
np.savetxt(save_path +'.txt',fc2,fmt='%s') // 保存特征文件
def read_image(rootdir,save_path):
list = os.listdir(rootdir) #列出文件夹下所有的目录与文件
# print(list)
# files = []
for i in range(0,len(list)):
path = os.path.join(rootdir,list[i])
# print(path)
# subFiles = []
for file in os.listdir(path):
# subFiles.append(file)
savePath = os.path.join(save_path,file[:-4])
#print(file)
filename = os.path.join(path,file)
feature_extraction(filename,savePath)
print("successfully saved "+ file[:-4] +".txt !")
if __name__ == '__main__':
base_model = VGG19(weights='imagenet', include_top=True) //加载VGG19模型及参数
print("Model has been onload !")
rootdir = 'F:/shiyan/TensorFlow/retrain/data/train' //图片路径
save_path = "F:/python/VGG19_feature" // 提取特征文件保存路径
read_image(rootdir,save_path)
print("work has been done !")