keras之读取h5文件(三)

0 前言

要求:

  • 安装keras库
  • 需要一个h5文件进行读取
  • h5文件获取方式:
    1、keras之分类数字图片(二),该文章的会生成模型及参数,建议使用这种方法,掌握来龙去脉。
    2、或者直接获取:百度网盘
    提取码:gedt
    在这里插入图片描述

1 读取h5文件

目标:

  • 熟练使用keras存读取h5文件

步骤:

  • 1、加载模型
  • 2、打印模型权重
  • 3、使用test数据集验证,加载的模型是否完整

1.1 加载模型

from datetime import datetime
from keras.models import load_model

# 如下的文件位置,根据自己电脑的文件位置更改
model = load_model('dentify_writtern_number_20210129_1209/epoch:10-loss:0.2525.h5')

# 打印加载的模型结构
print (model.summary())

model已经加载进去,我们可以看一下模型的结构
运行结果:
在这里插入图片描述

1.2 打印模型权重

for i in range(len(model.layers)):
  if len(model.layers[i].get_weights())==1:
    w = model.layers[i].get_weights()
    b=0
  if len(model.layers[i].get_weights())==2:
    w,b = model.layers[i].get_weights()
  
  print(w.shape,b.shape)
model.layers[1].weights

运行结果部分截图:
在这里插入图片描述

1.3 使用test集验证

随机网上找一个图片验证。

import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt

# 自己网上随便找一个jpg图片
img = cv.imread("8.jpg")
gray_img = cv.cvtColor(img, cv.COLOR_RGB2GRAY)
img_resize = cv.resize(gray_img,(28,28))
plt.imshow(img_resize)
plt.show()
# 以上进行图片灰度化,以及reshape操作,使其符合我们数据要求

print(gray_img.shape)
y_pred = model.predict(img_resize.reshape((1,28,28,1)))
print (y_pred)
print (np.argmax(y_pred))

运行结果:
在这里插入图片描述

2 源代码

# 1、加载模型
from datetime import datetime
from keras.models import load_model

# 如下的文件位置,根据自己电脑的文件位置更改
model = load_model('dentify_writtern_number_20210129_1209/epoch:10-loss:0.2525.h5')

# 打印加载的模型结构
print (model.summary())


# 2、打印模型权重
for i in range(len(model.layers)):
  if len(model.layers[i].get_weights())==1:
    w = model.layers[i].get_weights()
    b=0
  if len(model.layers[i].get_weights())==2:
    w,b = model.layers[i].get_weights()
  
  print(w.shape,b.shape)
model.layers[1].weights


# 3、使用test集验证
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt

# 自己网上随便找一个jpg图片
img = cv.imread("8.jpg")
gray_img = cv.cvtColor(img, cv.COLOR_RGB2GRAY)
img_resize = cv.resize(gray_img,(28,28))
plt.imshow(img_resize)
plt.show()
# 以上进行图片灰度化,以及reshape操作,使其符合我们数据要求

print(gray_img.shape)
y_pred = model.predict(img_resize.reshape((1,28,28,1)))
print (y_pred)
print (np.argmax(y_pred))

如有疑惑,以下评论区留言。力所能及,必答之。

猜你喜欢

转载自blog.csdn.net/weixin_41466575/article/details/113456345