Sistema de reconocimiento de mascotas basado en TensorFlow2 (rastreador, entrenamiento y ajuste de modelos, implementación de modelos)

Tabla de contenido

entorno de desarrollo

0 preparación de proyecto

1 preparación del conjunto de datos

2 Preprocesamiento de datos

3 Construye el modelo

4 Entrenamiento y verificación del modelo

5 Implementación del modelo

6 dirección del proyecto

entorno de desarrollo

Autor: Duzhouyyds
Hora: 25 de agosto de 2023
Herramientas de desarrollo integradas: PyCharm Professional 2021.1
Entorno de desarrollo integrado: Python 3.10.6
Bibliotecas de terceros: tensorflow-gpu==2.10.0, cv2==4.7.0, gevent, functools, logging , solicitudes, sistema operativo, gradiente, matplotlib, aleatorio

0 preparación de proyecto

        Esta parte establece principalmente algunos hiperparámetros en el proyecto para que los lectores puedan modificar estos hiperparámetros de acuerdo con sus propias condiciones y aún así ejecutarse normalmente.

# -*- coding: utf-8 -*-
# @File: settings.py
# @Author: 嘟粥yyds
# @Time: 2023/08/25

# ##########爬虫############

# 图片类别和搜索关键词的映射关系
IMAGE_CLASS_KEYWORD_MAP = {
    'cat': '宠物猫',
    'dog': '宠物狗',
    'mouse': '宠物鼠',
    'rabbit': '宠物兔'
}
# 图片保存根目录
IMAGES_ROOT = './images'
# 爬虫每个类别下载多少页图片
SPIDER_DOWNLOAD_PAGES = 20

# #########数据###########

# 每个类别选取的图片数量
SAMPLES_PER_CLASS = 305
# 参与训练的类别
CLASSES = ['cat', 'dog', 'mouse', 'rabbit']
# 参与训练的类别数量
CLASS_NUM = len(CLASSES)
# 类别->编号的映射
CLASS_CODE_MAP = {
    'cat': 0,
    'dog': 1,
    'mouse': 2,
    'rabbit': 3
}
# 编号->类别的映射
CODE_CLASS_MAP = {
    0: '猫',
    1: '狗',
    2: '鼠',
    3: '兔'
}
# 随机数种子
RANDOM_SEED = 13  # 四个类别时样本较为均衡的随机数种子
# RANDOM_SEED = 19  # 三个类别时样本较为均衡的随机数种子
# 训练集比例
TRAIN_DATASET = 0.6
# 开发集比例
DEV_DATASET = 0.2
# 测试集比例
TEST_DATASET = 0.2
# mini_batch大小
BATCH_SIZE = 16
# imagenet数据集均值
IMAGE_MEAN = [0.485, 0.456, 0.406]
# imagenet数据集标准差
IMAGE_STD = [0.299, 0.224, 0.225]

# #########训练#########

# 学习率
LEARNING_RATE = 0.001
# 训练epoch数
TRAIN_EPOCHS = 30
# 保存训练模型的路径
MODEL_PATH = './model.h5'

1 preparación del conjunto de datos

        Este artículo no utiliza ningún conjunto de datos públicos para completar esta tarea, sino que utiliza un rastreador web para rastrear los materiales del conjunto de datos requeridos de Internet y luego los filtra manualmente para formar el conjunto de datos final para capacitación, verificación y prueba. .

        Para los rastreadores, la elección del motor de búsqueda es muy importante. Actualmente, sólo existen dos motores de búsqueda de uso común: Google y Baidu. Utilicé Google y Baidu para realizar búsquedas de imágenes respectivamente y descubrí que los resultados de búsqueda de Baidu eran mucho menos precisos que los de Google, así que elegí Google. Por lo tanto, mi código de rastreo se escribió basándose en Google. Para ejecutar mi código de rastreo, su red necesita poder acceder a Google. . Si su red no puede acceder a Google, puede considerar implementar usted mismo un programa de rastreo basado en Baidu. La lógica es la misma.

        Como quería que el proyecto fuera más liviano, no utilicé el marco scrapy. El rastreador se implementa mediante solicitudes + beautifulsoup4 y la concurrencia se implementa mediante gevent.

# -*- coding: utf-8 -*-
# @File: spider.py
# @Author: 嘟粥yyds
# @Time: 2023/08/25

from gevent import monkey

monkey.patch_all()  # 使整个程序能够利用gevent的协程特性
import functools
import logging
import os
from bs4 import BeautifulSoup
from gevent.pool import Pool
import requests
import settings

# 设置日志输出格式
logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
                    level=logging.INFO)

# 搜索关键词字典
keywords_map = settings.IMAGE_CLASS_KEYWORD_MAP

# 图片保存根目录
images_root = settings.IMAGES_ROOT
# 每个类别下载多少页图片
download_pages = settings.SPIDER_DOWNLOAD_PAGES
# 图片编号字典,每种图片都从0开始编号,然后递增
images_index_map = dict(zip(keywords_map.keys(), [0 for _ in keywords_map]))
# 图片去重器
duplication_filter = set()

# 请求头
headers = {
    'accept-encoding': 'gzip, deflate, br',
    'accept-language': 'zh-CN,zh;q=0.9',
    'user-agent': 'Mozilla/5.0 (Linux; Android 4.0.4; Galaxy Nexus Build/IMM76B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/46.0.2490.76 Mobile Safari/537.36',
    'accept': '*/*',
    'referer': 'https://www.google.com/',
    'authority': 'www.google.com',
}


# 重试装饰器
def try_again_while_except(max_times=3):
    """
    当出现异常时,自动重试。
    连续失败max_times次后放弃。
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            error_cnt = 0
            error_msg = ''
            while error_cnt < max_times:
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    error_msg = str(e)
                    error_cnt += 1
            if error_msg:
                logging.error(error_msg)

        return wrapper

    return decorator


@try_again_while_except()
def download_image(session, image_url, image_class):
    """
    从给定的url中下载图片,并保存到指定路径
    """
    # 下载图片
    resp = session.get(image_url, timeout=20)
    # 检查图片是否下载成功
    if resp.status_code != 200:
        raise Exception('Response Status Code {}!'.format(resp.status_code))
    # 分配一个图片编号
    image_index = images_index_map.get(image_class, 0)
    # 更新待分配编号
    images_index_map[image_class] = image_index + 1
    # 拼接图片路径
    image_path = os.path.join(images_root, image_class, '{}.jpg'.format(image_index))
    # 保存图片
    with open(image_path, 'wb') as f:
        f.write(resp.content)
    # 成功写入了一张图片
    return True


@try_again_while_except()
def get_and_analysis_google_search_page(session, page, image_class, keyword):
    """
    使用google进行搜索,下载搜索结果页面,解析其中的图片地址,并对有效图片进一步发起请求
    """
    logging.info('Class:{} Page:{} Processing...'.format(image_class, page + 1))
    # 记录从本页成功下载的图片数量
    downloaded_cnt = 0
    # 构建请求参数
    params = (
        ('q', keyword),  # 查询关键词
        ('tbm', 'isch'),  # 搜索媒体类型:图片
        ('async', '_id:islrg_c,_fmt:html'),  # 使用异步模式
        ('asearch', 'ichunklite'),  # 使用高级搜索
        ('start', str(page * 100)),  # Google每页大概显示100张图片
        ('ijn', str(page)),  # 搜索结果的页面索引
    )
    # 进行搜索
    resp = requests.get('https://www.google.com/search', params=params, timeout=20)
    # 解析搜索结果
    bsobj = BeautifulSoup(resp.content, 'lxml')
    divs = bsobj.find_all('div', {'class': 'islrtb isv-r'})
    for div in divs:
        image_url = div.get('data-ou')
        # 只有当图片以'.jpg','.jpeg','.png'结尾时才下载图片
        if image_url.endswith('.jpg') or image_url.endswith('.jpeg') or image_url.endswith('.png'):
            # 过滤掉相同图片
            if image_url not in duplication_filter:
                # 使用去重器记录
                duplication_filter.add(image_url)
                # 下载图片
                flag = download_image(session, image_url, image_class)
                if flag:
                    downloaded_cnt += 1
    logging.info('Class:{} Page:{} Done. {} images downloaded.'.format(image_class, page + 1, downloaded_cnt))


def search_with_google(image_class, keyword):
    """
    通过google下载数据集
    """
    # 创建session对象
    session = requests.session()
    session.headers.update(headers)
    # 每个类别下载20页数据
    for page in range(download_pages):
        get_and_analysis_google_search_page(session, page, image_class, keyword)


def run():
    # 首先,创建数据文件夹
    if not os.path.exists(images_root):
        os.mkdir(images_root)
    for sub_images_dir in keywords_map.keys():
        # 对于每个图片类别都创建一个单独的文件夹保存
        sub_path = os.path.join(images_root, sub_images_dir)
        if not os.path.exists(sub_path):
            os.mkdir(sub_path)
    # 开始下载,这里使用gevent的协程池进行并发
    pool = Pool(len(keywords_map))
    for image_class, keyword in keywords_map.items():
        pool.spawn(search_with_google, image_class, keyword)
    pool.join()


if __name__ == '__main__':
    run()

        El rastreador utiliza Google para búsquedas de imágenes, busca en 20 páginas para cada mascota y descarga todas las imágenes que contienen. Cuando el rastreador termine de ejecutarse, habrá una imagescarpeta adicional debajo del proyecto. Haga clic en ella y habrá cuatro subcarpetas, a saber cat, dog, mouse, rabbit. Cada subcarpeta contiene imágenes de mascotas de la categoría correspondiente.

        Entre ellos, hay más de 580 imágenes de gatos, más de 570 imágenes de perros, más de 390 imágenes de ratas y más de 480 imágenes de conejos. Se necesitaron unos veinte minutos para examinar todas las imágenes rastreadas y eliminar aquellas que no cumplían los requisitos. Tenga en cuenta que este paso es obligatorio y debe tomarse en serio. (Si este paso se realiza bien, la precisión del modelo final se puede aumentar entre 8 y 10 puntos porcentuales, como lo ha experimentado personalmente el blogger)

        Después de una ronda de proyección, quedaron la cantidad de imágenes:

mascota Número de fotos
Gato 435
perro 468
ratón 305
conejo 434

        Teniendo en cuenta el problema de equilibrar muestras de cada categoría, no es más que un sobremuestreo y un submuestreo. Debido a que se trata de datos de imagen, la mejora de datos también se puede utilizar para generar algunas imágenes para categorías con una pequeña cantidad de imágenes para equilibrar la cantidad de muestras. Pero por las siguientes razones, submuestré directamente, es decir, solo se seleccionaron 305 muestras para cada categoría:

        Si utiliza el aumento de datos, debe regenerar un conjunto de datos basado en la imagen original. Después de usar la mejora de datos, la cantidad de muestras es relativamente grande y no se pueden leer en la memoria al mismo tiempo. Solo puede escribir un generador y leerlo desde el disco duro en tiempo real al procesar qué parte. Esta desventaja sigue siendo obvia: la lectura frecuente del disco duro reducirá la velocidad de entrenamiento.

        Por supuesto, es concebible que al usar la mejora de datos (aquí, la mejora de datos se puede usar como método de sobremuestreo) para aumentar el número de muestras de datos a 468, el efecto de entrenamiento definitivamente será mejor, pero no sé cómo. mucho mejor será. En comparación con la solución que elegí, los lectores pueden implementarla ellos mismos si están interesados.

2 Preprocesamiento de datos

        Dado que el formato de entrada recibido por muchos modelos clásicos es (Ninguno, 224, 224, 3), debido a que tenemos menos muestras, inevitablemente necesitamos usar el aprendizaje por transferencia, por lo que nuestro formato de datos es consistente con el modelo clásico y también usamos ( Ninguno, 224,224,3), el siguiente es el proceso de preprocesamiento:

# -*- coding: utf-8 -*-
# @File: data.py
# @Author: 嘟粥yyds
# @Time: 2023/08/25

import os
import random
import tensorflow as tf
import settings

# 每个类别选取的图片数量
samples_per_class = settings.SAMPLES_PER_CLASS
# 图片根目录
images_root = settings.IMAGES_ROOT
# 类别->编码的映射
class_code_map = settings.CLASS_CODE_MAP

# 我们准备使用经典网络在imagenet数据集上的与训练权重,所以归一化时也要使用imagenet的平均值和标准差
image_mean = tf.constant(settings.IMAGE_MEAN)
image_std = tf.constant(settings.IMAGE_STD)


def normalization(x):
    """
    对输入图片x进行归一化,返回归一化的值
    """
    return (x - image_mean) / image_std


def train_preprocess(x, y):
    """
    对训练数据进行预处理。
    注意,这里的参数x是图片的路径,不是图片本身;y是图片的标签值
    """
    # 读取图片
    x = tf.io.read_file(x)
    # 解码成张量
    x = tf.image.decode_jpeg(x, channels=3)
    # 将图片缩放到[244,244],比输入[224,224]稍大一些,方便后面数据增强
    x = tf.image.resize(x, [244, 244])
    # 随机决定是否左右镜像
    if random.choice([0, 1]):
        x = tf.image.random_flip_left_right(x)
    # 随机从x中剪裁出(224,224,3)大小的图片
    x = tf.image.random_crop(x, [224, 224, 3])
    # 读完上面的代码可以发现,这里的数据增强并不增加图片数量,一张图片经过变换后,
    # 仍然只是一张图片,跟我们前面说的增加图片数量的逻辑不太一样。
    # 这么做主要是应对我们的数据集里可能会存在相同图片的情况。

    # 将图片的像素值缩放到[0,1]之间
    x = tf.cast(x, dtype=tf.float32) / 255.
    # 归一化
    x = normalization(x)

    # 将标签转成one-hot形式
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, settings.CLASS_NUM)

    return x, y


def dev_preprocess(x, y):
    """
    对验证集和测试集进行数据预处理的方法。
    和train_preprocess的主要区别在于,不进行数据增强,以保证验证结果的稳定性。
    """
    # 读取并缩放图片
    x = tf.io.read_file(x)
    x = tf.image.decode_jpeg(x, channels=3)
    x = tf.image.resize(x, [224, 224])
    # 归一化
    x = tf.cast(x, dtype=tf.float32) / 255.
    x = normalization(x)
    # 将标签转成one-hot形式
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, settings.CLASS_NUM)

    return x, y


# (图片路径,标签)的列表
image_path_and_labels = []
# 排序,保证每次拿到的顺序都一样
sub_images_dir_list = sorted(list(os.listdir(images_root)))
# 遍历每一个子目录
for sub_images_dir in sub_images_dir_list:
    sub_path = os.path.join(images_root, sub_images_dir)
    # 如果给定路径是文件夹,并且这个类别参与训练
    if os.path.isdir(sub_path) and sub_images_dir in settings.CLASSES:
        # 获取当前类别的编码
        current_label = class_code_map.get(sub_images_dir)
        # 获取子目录下的全部图片名称
        images = sorted(list(os.listdir(sub_path)))
        # 随机打乱(排序和置随机数种子都是为了保证每次的结果都一样)
        random.seed(settings.RANDOM_SEED)
        random.shuffle(images)
        # 保留前settings.SAMPLES_PER_CLASS个
        images = images[:samples_per_class]
        # 构建(x,y)对
        for image_name in images:
            abs_image_path = os.path.join(sub_path, image_name)
            image_path_and_labels.append((abs_image_path, current_label))
# 计算各数据集样例数
total_samples = len(image_path_and_labels)  # 总样例数
train_samples = int(total_samples * settings.TRAIN_DATASET)  # 训练集样例数
dev_samples = int(total_samples * settings.DEV_DATASET)  # 开发集样例数
test_samples = total_samples - train_samples - dev_samples  # 测试集样例数
# 打乱数据集
random.seed(settings.RANDOM_SEED)
random.shuffle(image_path_and_labels)
# 将图片数据和标签数据分开,此时它们仍是一一对应的
x_data = tf.constant([img for img, label in image_path_and_labels])
y_data = tf.constant([label for img, label in image_path_and_labels])
# 开始划分数据集
# 训练集
train_db = tf.data.Dataset.from_tensor_slices((x_data[:train_samples], y_data[:train_samples]))
# 打乱顺序,数据预处理,设置批大小
train_db = train_db.shuffle(10000).map(train_preprocess).batch(settings.BATCH_SIZE)
# 开发集(验证集)
dev_db = tf.data.Dataset.from_tensor_slices(
    (x_data[train_samples:train_samples + dev_samples], y_data[train_samples:train_samples + dev_samples]))
# 数据预处理,设置批大小
dev_db = dev_db.map(dev_preprocess).batch(settings.BATCH_SIZE)
# 测试集
test_db = tf.data.Dataset.from_tensor_slices(
    (x_data[train_samples + dev_samples:], y_data[train_samples + dev_samples:]))
# 数据预处理,设置批大小
test_db = test_db.map(dev_preprocess).batch(settings.BATCH_SIZE)

3 Construye el modelo

        Ahora que todos los datos están procesados, es hora de pensar en el modelo. En primer lugar, nuestro conjunto de datos es demasiado pequeño y, obviamente, construir nuestra propia red y entrenarla directamente no es una buena solución. Debido a que estos tipos de mascotas son bastante difíciles de distinguir, el modelo debe tener cierta complejidad para ajustarse bien a estos datos, pero nuestros datos son demasiado pequeños y el resultado final debe estar sobreajustado, por lo que consideramos migrar desde Comenzar aprendiendo .

        En general, se cree que el entrenamiento de redes neuronales convolucionales profundas es un proceso paso a paso para extraer características del conjunto de datos, desde características simples hasta características complejas. Lo que aprende el modelo entrenado es el método de extracción de características de la imagen, por lo que, en  teoría, el modelo entrenado en el conjunto de datos de imagenet  también se puede utilizar directamente para extraer características de otras imágenes, que también es la base del aprendizaje por transferencia. Naturalmente, este efecto a menudo no es tan bueno como volver a entrenar con datos nuevos, pero puede ahorrar mucho tiempo de entrenamiento y es muy útil en determinadas situaciones. Y este caso particular también incluye el que enfrentamos: el conjunto de datos para el problema real es demasiado pequeño.

        Hablando de aprendizaje por transferencia, lo primero que pensé fue en la serie VGG, así que la ejecuté una vez con VGG19. Utilice  la red VGG19 previamente entrenada en el conjunto de datos de imagenet  , elimine la capa superior completamente conectada y congele todos los parámetros para que no cambien durante el entrenamiento posterior. Luego agregue su propia capa completamente conectada y el nodo de la capa de salida final es 4, correspondiente a nuestro problema de clasificación de cuatro. Empezar a entrenar.

        El rendimiento de error del modelo en el conjunto de entrenamiento es bastante bueno, pero la precisión en el conjunto de validación es básicamente del 70+%. Evidentemente, este modelo ha sido sobreajustado.

        Entonces, me concentré en DenseNet121, que tiene solo 7 millones de parámetros. Efectivamente, después de un período de ajuste, el rendimiento del modelo mejoró significativamente, alcanzando aproximadamente el 91% en el conjunto de entrenamiento y aproximadamente el 93% en el conjunto de verificación. Para DenseNet121, este problema ya no es un sobreajuste, sino un desajuste. Es decir, el tamaño del conjunto de datos es demasiado pequeño.

# -*- coding: utf-8 -*-
# @File: models.py
# @Author: 嘟粥yyds
# @Time: 2023/08/25

import tensorflow as tf
import settings
from tensorflow.keras.utils import plot_model


def my_densenet():
    """
    创建并返回一个基于densenet的Model对象
    """
    # 获取densenet网络,使用在imagenet上训练的参数值,移除头部的全连接网络,池化层使用max_pooling
    densenet = tf.keras.applications.DenseNet121(include_top=False, weights='imagenet', pooling='max')
    # 冻结预训练的参数,在之后的模型训练中不会改变它们
    densenet.trainable = False
    # 构建模型
    model = tf.keras.Sequential([
        # 输入层,shape为(None,224,224,3)
        tf.keras.layers.Input((224, 224, 3)),
        # 输入到DenseNet121中
        densenet,
        # 将DenseNet121的输出展平,以作为全连接层的输入
        tf.keras.layers.Flatten(),
        # 添加BN层
        tf.keras.layers.BatchNormalization(),
        # 随机失活
        tf.keras.layers.Dropout(0.5),
        # 第一个全连接层,激活函数relu
        tf.keras.layers.Dense(512, activation=tf.nn.relu),
        # BN层
        tf.keras.layers.BatchNormalization(),
        # 随机失活
        tf.keras.layers.Dropout(0.5),
        # 第二个全连接层,激活函数relu
        tf.keras.layers.Dense(64, activation=tf.nn.relu),
        # BN层
        tf.keras.layers.BatchNormalization(),
        # 输出层,为了保证输出结果的稳定,这里就不添加Dropout层了
        tf.keras.layers.Dense(settings.CLASS_NUM, activation=tf.nn.softmax)
    ])

    return model


if __name__ == '__main__':
    model = my_densenet()
    model.summary()
    plot_model(model, show_shapes=True, to_file='model.png', dpi=200)

 Resumen del modelo:

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 densenet121 (Functional)    (None, 1024)              7037504   
                                                                 
 flatten (Flatten)           (None, 1024)              0         
                                                                 
 batch_normalization (BatchN  (None, 1024)             4096      
 ormalization)                                                   
                                                                 
 dropout (Dropout)           (None, 1024)              0         
                                                                 
 dense (Dense)               (None, 512)               524800    
                                                                 
 batch_normalization_1 (Batc  (None, 512)              2048      
 hNormalization)                                                 
                                                                 
 dropout_1 (Dropout)         (None, 512)               0         
                                                                 
 dense_1 (Dense)             (None, 64)                32832     
                                                                 
 batch_normalization_2 (Batc  (None, 64)               256       
 hNormalization)                                                 
                                                                 
 dense_2 (Dense)             (None, 4)                 260       
                                                                 
=================================================================
Total params: 7,601,796
Trainable params: 561,092
Non-trainable params: 7,040,704
_________________________________________________________________

El número total de parámetros es 7.601.796, de los cuales 561.092 son parámetros entrenables.

4 Entrenamiento y verificación del modelo

El modelo y los datos están listos y puede comenzar el entrenamiento. Escribamos un guión de entrenamiento:

# -*- coding: utf-8 -*-
# @File: train.py
# @Author: 嘟粥yyds
# @Time: 2023/08/25

import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard
from data import train_db, dev_db
import models
import settings

# 从models文件中导入模型
model = models.my_densenet()

# 创建 TensorBoard 回调对象
tensorboard_callback = TensorBoard(log_dir='logs', histogram_freq=1, write_graph=True, write_images=True)

# 配置优化器、损失函数、以及监控指标
model.compile(tf.keras.optimizers.Adam(settings.LEARNING_RATE), loss=tf.keras.losses.categorical_crossentropy,
              metrics=['accuracy'])

# 在每个epoch结束后尝试保存模型参数,只有当前参数的val_accuracy比之前保存的更优时,才会覆盖掉之前保存的参数
model_check_point = ModelCheckpoint(filepath=settings.MODEL_PATH, monitor='val_accuracy',
                                    save_best_only=True)

# 创建早停回调对象
early_stopping = EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True)

# 创建学习率减少回调对象
lr_decay = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, min_lr=1e-6)

# 使用高级接口进行训练
model.fit(train_db, epochs=settings.TRAIN_EPOCHS, validation_data=dev_db,
          callbacks=[model_check_point, early_stopping, lr_decay, tensorboard_callback])

        Ahora podemos ejecutar el script para entrenamiento y los parámetros óptimos se guardarán en formato settings.MODEL_PATH. Una vez completada la capacitación, debemos llamar al siguiente script de verificación para verificar el rendimiento del modelo en el conjunto de verificación y el conjunto de prueba: 

# -*- coding: utf-8 -*-
# @File    : eval.py
# @Author  : 嘟粥yyds
# @Time    : 2023/08/25

import tensorflow as tf
from data import dev_db, test_db
from models import my_densenet
import settings

# 创建模型
model = my_densenet()
# 加载参数
model.load_weights(settings.MODEL_PATH)
# 因为想用tf.keras的高级接口做验证,所以还是需要编译模型
model.compile(tf.keras.optimizers.Adam(settings.LEARNING_RATE), loss=tf.keras.losses.categorical_crossentropy,
              metrics=['accuracy'])
# 验证集accuracy
print('dev', model.evaluate(dev_db))
# 测试集accuracy
print('test', model.evaluate(test_db))

# 查看识别错误的数据
for x, y in test_db:
    y_pred = model(x)
    y_pred = tf.argmax(y_pred, axis=1).numpy()
    y_true = tf.argmax(y, axis=1).numpy()
    batch_size = y_pred.shape[0]
    for i in range(batch_size):
        if y_pred[i] != y_true[i]:
            print('{} 被错误识别成 {}!'.format(settings.CODE_CLASS_MAP[y_true[i]], settings.CODE_CLASS_MAP[y_pred[i]]))
16/16 [==============================] - 9s 99ms/step - loss: 0.1439 - accuracy: 0.9713
dev [0.1438767910003662, 0.9713114500045776]
16/16 [==============================] - 1s 85ms/step - loss: 0.1606 - accuracy: 0.9549
test [0.16057191789150238, 0.9549180269241333]
猫 被错误识别成 兔!
猫 被错误识别成 鼠!
猫 被错误识别成 狗!
鼠 被错误识别成 兔!
猫 被错误识别成 狗!
兔 被错误识别成 猫!
兔 被错误识别成 猫!
猫 被错误识别成 兔!
狗 被错误识别成 鼠!
猫 被错误识别成 兔!
狗 被错误识别成 兔!

        Se puede ver que la precisión del modelo en el conjunto de verificación es del 97,13% y la precisión del modelo en el conjunto de prueba es del 95,49%, lo que ha alcanzado mis expectativas, después de todo, los datos utilizados son realmente muy pequeños.

5 Implementación del modelo

        La implementación del modelo de este proyecto todavía utiliza Gradio para la implementación, y sus ventajas son evidentes: la conveniencia.

import gradio as gr
import tensorflow as tf
import settings
from models import my_densenet
import matplotlib as mpl
mpl.use('TkAgg')


# 导入模型
model = my_densenet()
# 加载训练好的参数
model.load_weights(settings.MODEL_PATH)


def classify_pet_image(input_image):
    """
    宠物图片分类接口,上传一张图片,返回此图片上的宠物是哪种类别,概率多少
    """
    # 进行数据预处理
    # x = tf.image.decode_image(input_image, channels=3)
    x = tf.convert_to_tensor(input_image)
    x = tf.image.resize(x, (224, 224))
    x = x / 255.
    x = (x - tf.constant(settings.IMAGE_MEAN)) / tf.constant(settings.IMAGE_STD)
    x = tf.reshape(x, (1, 224, 224, 3))
    # 预测
    y_pred = model(x)
    pet_cls_code = tf.argmax(y_pred, axis=1).numpy()[0]
    pet_cls_prob = float(y_pred.numpy()[0][pet_cls_code])
    pet_cls_prob = '{}%'.format(int(pet_cls_prob * 100))
    pet_class = settings.CODE_CLASS_MAP.get(pet_cls_code)
    # 格式化输出为纯文本
    output_text = "宠物类别:{}  \n概率:{}".format(pet_class, pet_cls_prob)

    return output_text


gr.close_all()
demo = gr.Interface(fn=classify_pet_image,
          inputs=[gr.Image(label="Upload image")],
          outputs=[gr.Textbox(label="识别结果")],
          title="宠物识别Demo",
          description="Classify your pet!",
          allow_flagging="never"
                   )

demo.launch(share=True, debug=True, server_port=10055)

  

6 dirección del proyecto

Github: GitHub - 0911duzhou/OpenCV-Pet_Classifer: sistema de reconocimiento de mascotas basado en TensorFlow2 (rastreador, entrenamiento y ajuste de modelos, implementación de modelos)

 Si no puede acceder a Github, también puede descargarlo desde los recursos de la página de inicio del blogger.

Supongo que te gusta

Origin blog.csdn.net/zzp20031120/article/details/132496435
Recomendado
Clasificación