Python y aprendizaje profundo (ocho): CNN y fashion_mnist dos

1. Descripción

Este artículo es para probar el modelo entrenado en el artículo anterior. El primero es volver a cargar el modelo entrenado, luego usar opencv para cargar la imagen y finalmente enviar la imagen cargada al modelo y mostrar el resultado.

2. Prueba modelo CNN de fashion_mnist

2.1 Importar bibliotecas relacionadas

Importe aquí la biblioteca de terceros requerida, como cv2; si no, debe descargarla usted mismo.

from tensorflow import keras
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw  # PIL就是pillow包(保存图像)
import numpy as np
# 导入tensorflow
import tensorflow as tf
# 导入keras
from tensorflow import keras
from keras.datasets import fashion_mnist

2.2 Cargar datos y modelo

Cargue el conjunto de datos fashion_mnist y cargue el modelo entrenado.

# fashion数据集列表
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 加载fashion数据
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# 加载cnn_fashion.h5文件,重新生成模型对象
recons_model = keras.models.load_model('cnn_fashion.h5')

2.3 Establecer la ruta para guardar la imagen

Guarde ciertos datos del conjunto de datos en forma de imagen, lo cual es conveniente para la visualización de la prueba.
Establezca aquí la ubicación de almacenamiento de la imagen.

# 创建图片保存路径
test_file_path = os.path.join(sys.path[0], 'imgs', 'test100.png')
# 存储测试数据的任意一个
Image.fromarray(x_test[100]).save(test_file_path)

Después de escribir el código anterior, debe crear una carpeta imgs en la ruta actual del código para almacenar imágenes, de la siguiente manera.
inserte la descripción de la imagen aquí

Después de ejecutar el código anterior, se puede encontrar una imagen más en el archivo imgs, de la siguiente manera (probado muchas veces a continuación).
inserte la descripción de la imagen aquí

2.4 Cargar imágenes

Use cv2 para cargar la imagen. La razón por la cual la última línea de código a continuación toma un canal es que cuando se usa la biblioteca opencv, es decir, cv2 para leer la imagen, la imagen es de tres canales y el modelo entrenado es de un solo canal. canal, por lo que se toma el canal único.

# 加载本地test.png图像
image = cv2.imread(test_file_path)
# 复制图片
test_img = image.copy()
# 将图片大小转换成(28,28)
test_img = cv2.resize(test_img, (28, 28))
# 取单通道值
test_img = test_img[:, :, 0]

2.5 Preprocesamiento de imágenes

Preprocesar la imagen, es decir, normalizar y cambiar la forma, es facilitar la entrada de la imagen al modelo entrenado para la predicción.

# 预处理: 归一化 + reshape
new_test_img = (test_img/255.0).reshape(1, 28, 28, 1)

2.6 Predicción de imágenes

Ingrese la imagen al modelo entrenado y haga predicciones.
El resultado pronosticado es de 10 valores de probabilidad, por lo que debe procesarse. np.argmax() es el número de serie del valor máximo del valor de probabilidad, que es el número pronosticado.

# 预测
y_pre_pro = recons_model.predict(new_test_img, verbose=1)
# 哪一类
class_id = np.argmax(y_pre_pro, axis=1)[0]
print('test.png的预测概率:', y_pre_pro)
print('test.png的预测概率:', y_pre_pro[0, class_id])
print('test.png的所属类别:', class_names[class_id])
text = str(class_names[class_id])

2.7 Mostrar imágenes

Muestre la imagen predicha y muestre el número predicho en la imagen.
Las siguientes 5 líneas de código son para crear la ventana, establecer el tamaño de la ventana, mostrar la imagen, mantener la imagen y borrar la memoria.

# # 显示
cv2.namedWindow('img', 0)
cv2.resizeWindow('img', 500, 500)  # 自己设定窗口图片的大小
cv2.imshow('img', image)
cv2.waitKey()
cv2.destroyAllWindows()

3. Código completo y visualización de resultados.

A continuación se muestra el código completo y una imagen que muestra el resultado.

from tensorflow import keras
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw  # PIL就是pillow包(保存图像)
import numpy as np
# 导入tensorflow
import tensorflow as tf
# 导入keras
from tensorflow import keras
from keras.datasets import fashion_mnist
# fashion数据集列表
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 加载fashion数据
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# 加载cnn_fashion.h5文件,重新生成模型对象
recons_model = keras.models.load_model('cnn_fashion.h5')
# 创建图片保存路径
test_file_path = os.path.join(sys.path[0], 'imgs', 'test100.png')
# 存储测试数据的任意一个
Image.fromarray(x_test[100]).save(test_file_path)
# 加载本地test.png图像
image = cv2.imread(test_file_path)
# 复制图片
test_img = image.copy()
# 将图片大小转换成(28,28)
test_img = cv2.resize(test_img, (28, 28))
# 取单通道值
test_img = test_img[:, :, 0]
# 预处理: 归一化 + reshape
new_test_img = (test_img/255.0).reshape(1, 28, 28, 1)
# 预测
y_pre_pro = recons_model.predict(new_test_img, verbose=1)
# 哪一类
class_id = np.argmax(y_pre_pro, axis=1)[0]
print('test.png的预测概率:', y_pre_pro)
print('test.png的预测概率:', y_pre_pro[0, class_id])
print('test.png的所属类别:', class_names[class_id])
text = str(class_names[class_id])
# # 显示
cv2.namedWindow('img', 0)
cv2.resizeWindow('img', 500, 500)  # 自己设定窗口图片的大小
cv2.imshow('img', image)
cv2.waitKey()
cv2.destroyAllWindows()

To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
1/1 [==============================] - 0s 168ms/step
test.png的预测概率: [[2.9672831e-04 7.3040414e-05 1.4721525e-04 9.9842703e-01 4.7597905e-06
  8.9959512e-06 1.0416918e-03 8.6147125e-09 4.2549357e-07 1.2974965e-07]]
test.png的预测概率: 0.99842703
test.png的所属类别: Dress

inserte la descripción de la imagen aquí

4. El código completo y los resultados de las pruebas con varias imágenes.

Para probar más imágenes, se introduce un bucle para realizar múltiples pruebas y el efecto es mejor.

from tensorflow import keras
from keras.datasets import fashion_mnist
import skimage, os, sys, cv2
from PIL import ImageFont, Image, ImageDraw  # PIL就是pillow包(保存图像)
import numpy as np

# fashion数据集列表
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# 加载mnist数据
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
# 加载cnn_fashion.h5文件,重新生成模型对象
recons_model = keras.models.load_model('cnn_fashion.h5')

prepicture = int(input("input the number of test picture :"))
for i in range(prepicture):
    path1 = input("input the test picture path:")
    # 创建图片保存路径
    test_file_path = os.path.join(sys.path[0], 'imgs', path1)
    # 存储测试数据的任意一个
    num = int(input("input the test picture num:"))
    Image.fromarray(x_test[num]).save(test_file_path)
    # 加载本地test.png图像
    image = cv2.imread(test_file_path)
    # 复制图片
    test_img = image.copy()
    # 将图片大小转换成(28,28)
    test_img = cv2.resize(test_img, (28, 28))
    # 取单通道值
    test_img = test_img[:, :, 0]
    # 预处理: 归一化 + reshape
    new_test_img = (test_img/255.0).reshape(1, 28, 28, 1)
    # 预测
    y_pre_pro = recons_model.predict(new_test_img, verbose=1)
    # 哪一类数字
    class_id = np.argmax(y_pre_pro, axis=1)[0]
    print('test.png的预测概率:', y_pre_pro)
    print('test.png的预测概率:', y_pre_pro[0, class_id])
    print('test.png的所属类别:', class_names[class_id])
    text = str(class_names[class_id])
    # # 显示
    cv2.namedWindow('img', 0)
    cv2.resizeWindow('img', 500, 500)  # 自己设定窗口图片的大小
    cv2.imshow('img', image)
    cv2.waitKey()
    cv2.destroyAllWindows()

To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
input the number of test picture :2
input the test picture path:101.jpg
input the test picture num:1
1/1 [==============================] - 0s 145ms/step
test.png的预测概率: [[5.1000708e-05 2.9449904e-13 9.9993873e-01 5.5402721e-11 4.8696438e-06
  1.2649738e-12 5.3379590e-06 6.5959898e-17 7.1223938e-10 4.0113624e-12]]
test.png的预测概率: 0.9999387
test.png的所属类别: Pullover

inserte la descripción de la imagen aquí

input the test picture path:102.jpg
input the test picture num:2
1/1 [==============================] - 0s 21ms/step
test.png的预测概率: [[3.01315001e-10 1.00000000e+00 1.03142118e-14 8.63922683e-11
  4.10812981e-11 6.07313693e-22 2.31636132e-09 5.08595438e-25
  1.02018335e-13 8.82350167e-28]]
test.png的预测概率: 1.0
test.png的所属类别: Trouser

inserte la descripción de la imagen aquí

Supongo que te gusta

Origin blog.csdn.net/qq_47598782/article/details/131968902
Recomendado
Clasificación