目次
開発環境
著者: Duzhouyyds
時刻: 2023 年 8 月 25 日
統合開発ツール: PyCharm Professional 2021.1
統合開発環境: Python 3.10.6
サードパーティ ライブラリ: tensorflow-gpu==2.10.0、cv2==4.7.0、gevent、functools、logging 、リクエスト、OS、勾配、matplotlib、ランダム
0 プロジェクトの準備
この部分では主に、プロジェクトにいくつかのハイパー パラメータを設定します。これにより、読者はこれらのハイパー パラメータを独自の条件に応じて変更し、通常どおり実行できます。
# -*- coding: utf-8 -*-
# @File: settings.py
# @Author: 嘟粥yyds
# @Time: 2023/08/25
# ##########爬虫############
# 图片类别和搜索关键词的映射关系
IMAGE_CLASS_KEYWORD_MAP = {
'cat': '宠物猫',
'dog': '宠物狗',
'mouse': '宠物鼠',
'rabbit': '宠物兔'
}
# 图片保存根目录
IMAGES_ROOT = './images'
# 爬虫每个类别下载多少页图片
SPIDER_DOWNLOAD_PAGES = 20
# #########数据###########
# 每个类别选取的图片数量
SAMPLES_PER_CLASS = 305
# 参与训练的类别
CLASSES = ['cat', 'dog', 'mouse', 'rabbit']
# 参与训练的类别数量
CLASS_NUM = len(CLASSES)
# 类别->编号的映射
CLASS_CODE_MAP = {
'cat': 0,
'dog': 1,
'mouse': 2,
'rabbit': 3
}
# 编号->类别的映射
CODE_CLASS_MAP = {
0: '猫',
1: '狗',
2: '鼠',
3: '兔'
}
# 随机数种子
RANDOM_SEED = 13 # 四个类别时样本较为均衡的随机数种子
# RANDOM_SEED = 19 # 三个类别时样本较为均衡的随机数种子
# 训练集比例
TRAIN_DATASET = 0.6
# 开发集比例
DEV_DATASET = 0.2
# 测试集比例
TEST_DATASET = 0.2
# mini_batch大小
BATCH_SIZE = 16
# imagenet数据集均值
IMAGE_MEAN = [0.485, 0.456, 0.406]
# imagenet数据集标准差
IMAGE_STD = [0.299, 0.224, 0.225]
# #########训练#########
# 学习率
LEARNING_RATE = 0.001
# 训练epoch数
TRAIN_EPOCHS = 30
# 保存训练模型的路径
MODEL_PATH = './model.h5'
1 データセットの準備
この記事では、このタスクを完了するために公開データ セットを使用しません。代わりに、Web クローラーを使用して必要なデータ セット素材をインターネットからクロールし、それらを手動でフィルターして、トレーニング、検証、テスト用の最終データ セットを形成します。 。
クローラーにとって、検索エンジンの選択は非常に重要です。現在、一般的に使用されている検索エンジンは Google と Baidu の 2 つだけです。Google と Baidu をそれぞれ使用して画像検索を実行しましたが、Baidu の検索結果は Google よりもはるかに精度が低いことがわかったので、Google を選択しました。そのため、クローラー コードは Google に基づいて作成されました。クローラー コードを実行するには、ネットワークが次のことを行う必要があります。 Googleにアクセスできること。ネットワークが Google にアクセスできない場合は、Baidu に基づいたクローラー プログラムを自分で実装することを検討できます。ロジックは同じです。
プロジェクトをより軽量にしたいため、Scrapy フレームワークは使用しませんでした。クローラーはrequests+beautysoup4を使用して実装され、同時実行性はgeventを使用して実装されます。
# -*- coding: utf-8 -*-
# @File: spider.py
# @Author: 嘟粥yyds
# @Time: 2023/08/25
from gevent import monkey
monkey.patch_all() # 使整个程序能够利用gevent的协程特性
import functools
import logging
import os
from bs4 import BeautifulSoup
from gevent.pool import Pool
import requests
import settings
# 设置日志输出格式
logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
level=logging.INFO)
# 搜索关键词字典
keywords_map = settings.IMAGE_CLASS_KEYWORD_MAP
# 图片保存根目录
images_root = settings.IMAGES_ROOT
# 每个类别下载多少页图片
download_pages = settings.SPIDER_DOWNLOAD_PAGES
# 图片编号字典,每种图片都从0开始编号,然后递增
images_index_map = dict(zip(keywords_map.keys(), [0 for _ in keywords_map]))
# 图片去重器
duplication_filter = set()
# 请求头
headers = {
'accept-encoding': 'gzip, deflate, br',
'accept-language': 'zh-CN,zh;q=0.9',
'user-agent': 'Mozilla/5.0 (Linux; Android 4.0.4; Galaxy Nexus Build/IMM76B) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/46.0.2490.76 Mobile Safari/537.36',
'accept': '*/*',
'referer': 'https://www.google.com/',
'authority': 'www.google.com',
}
# 重试装饰器
def try_again_while_except(max_times=3):
"""
当出现异常时,自动重试。
连续失败max_times次后放弃。
"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
error_cnt = 0
error_msg = ''
while error_cnt < max_times:
try:
return func(*args, **kwargs)
except Exception as e:
error_msg = str(e)
error_cnt += 1
if error_msg:
logging.error(error_msg)
return wrapper
return decorator
@try_again_while_except()
def download_image(session, image_url, image_class):
"""
从给定的url中下载图片,并保存到指定路径
"""
# 下载图片
resp = session.get(image_url, timeout=20)
# 检查图片是否下载成功
if resp.status_code != 200:
raise Exception('Response Status Code {}!'.format(resp.status_code))
# 分配一个图片编号
image_index = images_index_map.get(image_class, 0)
# 更新待分配编号
images_index_map[image_class] = image_index + 1
# 拼接图片路径
image_path = os.path.join(images_root, image_class, '{}.jpg'.format(image_index))
# 保存图片
with open(image_path, 'wb') as f:
f.write(resp.content)
# 成功写入了一张图片
return True
@try_again_while_except()
def get_and_analysis_google_search_page(session, page, image_class, keyword):
"""
使用google进行搜索,下载搜索结果页面,解析其中的图片地址,并对有效图片进一步发起请求
"""
logging.info('Class:{} Page:{} Processing...'.format(image_class, page + 1))
# 记录从本页成功下载的图片数量
downloaded_cnt = 0
# 构建请求参数
params = (
('q', keyword), # 查询关键词
('tbm', 'isch'), # 搜索媒体类型:图片
('async', '_id:islrg_c,_fmt:html'), # 使用异步模式
('asearch', 'ichunklite'), # 使用高级搜索
('start', str(page * 100)), # Google每页大概显示100张图片
('ijn', str(page)), # 搜索结果的页面索引
)
# 进行搜索
resp = requests.get('https://www.google.com/search', params=params, timeout=20)
# 解析搜索结果
bsobj = BeautifulSoup(resp.content, 'lxml')
divs = bsobj.find_all('div', {'class': 'islrtb isv-r'})
for div in divs:
image_url = div.get('data-ou')
# 只有当图片以'.jpg','.jpeg','.png'结尾时才下载图片
if image_url.endswith('.jpg') or image_url.endswith('.jpeg') or image_url.endswith('.png'):
# 过滤掉相同图片
if image_url not in duplication_filter:
# 使用去重器记录
duplication_filter.add(image_url)
# 下载图片
flag = download_image(session, image_url, image_class)
if flag:
downloaded_cnt += 1
logging.info('Class:{} Page:{} Done. {} images downloaded.'.format(image_class, page + 1, downloaded_cnt))
def search_with_google(image_class, keyword):
"""
通过google下载数据集
"""
# 创建session对象
session = requests.session()
session.headers.update(headers)
# 每个类别下载20页数据
for page in range(download_pages):
get_and_analysis_google_search_page(session, page, image_class, keyword)
def run():
# 首先,创建数据文件夹
if not os.path.exists(images_root):
os.mkdir(images_root)
for sub_images_dir in keywords_map.keys():
# 对于每个图片类别都创建一个单独的文件夹保存
sub_path = os.path.join(images_root, sub_images_dir)
if not os.path.exists(sub_path):
os.mkdir(sub_path)
# 开始下载,这里使用gevent的协程池进行并发
pool = Pool(len(keywords_map))
for image_class, keyword in keywords_map.items():
pool.spawn(search_with_google, image_class, keyword)
pool.join()
if __name__ == '__main__':
run()
クローラーは画像検索に Google を使用し、ペットごとに 20 ページを検索し、そのページ内のすべての画像をダウンロードします。クローラーの実行が終了すると、プロジェクトの下に追加のフォルダーが作成されます。クリックすると、 、、、images
という 4 つのサブフォルダーが表示されます。各サブフォルダーには、対応するカテゴリのペットの写真が含まれています。cat
dog
mouse
rabbit
その中には、580 枚以上の猫の写真、570 枚以上の犬の写真、390 枚以上のネズミの写真、480 枚以上のウサギの写真があります。クロールされたすべての画像をスクリーニングし、要件を満たさない画像を削除するのに約 20 分かかりました。この手順は必須であり、真剣に取り組む必要があることに注意してください。(ブロガーが個人的に経験したように、このステップが適切に実行されれば、最終モデルの精度は 8 ~ 10 パーセント ポイント向上します)
一連の上映後に残った写真の枚数:
ペット | 写真の枚数 |
猫 | 435 |
犬 | 468 |
ねずみ | 305 |
うさぎ | 434 |
各カテゴリのサンプルのバランスの問題を考慮すると、それはオーバーサンプリングとアンダーサンプリングに他なりません。これは画像データであるため、データ拡張を使用して、サンプル数のバランスをとるために、画像数が少ないカテゴリのいくつかの画像を生成することもできます。ただし、次の理由により、直接アンダーサンプリングしました。つまり、カテゴリごとに 305 個のサンプルのみが選択されました。
データ拡張を使用する場合は、元のイメージに基づいてデータ セットを再生成する必要があります。データエンハンスメントを使用した後は、サンプル数が比較的多く、一度にメモリに読み込むことができないため、どの部分を処理するかをリアルタイムでジェネレータを書いてハードディスクから読み込むことしかできません。この欠点は依然として明らかであり、ハードディスクを頻繁に読み取るとトレーニング速度が遅くなります。
もちろん、データエンハンスメント(ここではデータエンハンスメントはオーバーサンプリングの方法として使用できます)を使用してデータサンプル数を468に増やすことで、トレーニング効果が確実に向上することは考えられますが、方法はわかりませんずっと良くなるでしょう。私が選んだソリューションと比較すると、読者は興味があれば自分で実装できます。
2 データの前処理
多くのクラシック モデルが受け取る入力形式は (None, 224, 224, 3) であるため、サンプルが少ないため必然的に転移学習を使用する必要があるため、データ形式はクラシック モデルと一致しており、また (なし、224、224、3)、前処理プロセスは次のとおりです。
# -*- coding: utf-8 -*-
# @File: data.py
# @Author: 嘟粥yyds
# @Time: 2023/08/25
import os
import random
import tensorflow as tf
import settings
# 每个类别选取的图片数量
samples_per_class = settings.SAMPLES_PER_CLASS
# 图片根目录
images_root = settings.IMAGES_ROOT
# 类别->编码的映射
class_code_map = settings.CLASS_CODE_MAP
# 我们准备使用经典网络在imagenet数据集上的与训练权重,所以归一化时也要使用imagenet的平均值和标准差
image_mean = tf.constant(settings.IMAGE_MEAN)
image_std = tf.constant(settings.IMAGE_STD)
def normalization(x):
"""
对输入图片x进行归一化,返回归一化的值
"""
return (x - image_mean) / image_std
def train_preprocess(x, y):
"""
对训练数据进行预处理。
注意,这里的参数x是图片的路径,不是图片本身;y是图片的标签值
"""
# 读取图片
x = tf.io.read_file(x)
# 解码成张量
x = tf.image.decode_jpeg(x, channels=3)
# 将图片缩放到[244,244],比输入[224,224]稍大一些,方便后面数据增强
x = tf.image.resize(x, [244, 244])
# 随机决定是否左右镜像
if random.choice([0, 1]):
x = tf.image.random_flip_left_right(x)
# 随机从x中剪裁出(224,224,3)大小的图片
x = tf.image.random_crop(x, [224, 224, 3])
# 读完上面的代码可以发现,这里的数据增强并不增加图片数量,一张图片经过变换后,
# 仍然只是一张图片,跟我们前面说的增加图片数量的逻辑不太一样。
# 这么做主要是应对我们的数据集里可能会存在相同图片的情况。
# 将图片的像素值缩放到[0,1]之间
x = tf.cast(x, dtype=tf.float32) / 255.
# 归一化
x = normalization(x)
# 将标签转成one-hot形式
y = tf.cast(y, dtype=tf.int32)
y = tf.one_hot(y, settings.CLASS_NUM)
return x, y
def dev_preprocess(x, y):
"""
对验证集和测试集进行数据预处理的方法。
和train_preprocess的主要区别在于,不进行数据增强,以保证验证结果的稳定性。
"""
# 读取并缩放图片
x = tf.io.read_file(x)
x = tf.image.decode_jpeg(x, channels=3)
x = tf.image.resize(x, [224, 224])
# 归一化
x = tf.cast(x, dtype=tf.float32) / 255.
x = normalization(x)
# 将标签转成one-hot形式
y = tf.cast(y, dtype=tf.int32)
y = tf.one_hot(y, settings.CLASS_NUM)
return x, y
# (图片路径,标签)的列表
image_path_and_labels = []
# 排序,保证每次拿到的顺序都一样
sub_images_dir_list = sorted(list(os.listdir(images_root)))
# 遍历每一个子目录
for sub_images_dir in sub_images_dir_list:
sub_path = os.path.join(images_root, sub_images_dir)
# 如果给定路径是文件夹,并且这个类别参与训练
if os.path.isdir(sub_path) and sub_images_dir in settings.CLASSES:
# 获取当前类别的编码
current_label = class_code_map.get(sub_images_dir)
# 获取子目录下的全部图片名称
images = sorted(list(os.listdir(sub_path)))
# 随机打乱(排序和置随机数种子都是为了保证每次的结果都一样)
random.seed(settings.RANDOM_SEED)
random.shuffle(images)
# 保留前settings.SAMPLES_PER_CLASS个
images = images[:samples_per_class]
# 构建(x,y)对
for image_name in images:
abs_image_path = os.path.join(sub_path, image_name)
image_path_and_labels.append((abs_image_path, current_label))
# 计算各数据集样例数
total_samples = len(image_path_and_labels) # 总样例数
train_samples = int(total_samples * settings.TRAIN_DATASET) # 训练集样例数
dev_samples = int(total_samples * settings.DEV_DATASET) # 开发集样例数
test_samples = total_samples - train_samples - dev_samples # 测试集样例数
# 打乱数据集
random.seed(settings.RANDOM_SEED)
random.shuffle(image_path_and_labels)
# 将图片数据和标签数据分开,此时它们仍是一一对应的
x_data = tf.constant([img for img, label in image_path_and_labels])
y_data = tf.constant([label for img, label in image_path_and_labels])
# 开始划分数据集
# 训练集
train_db = tf.data.Dataset.from_tensor_slices((x_data[:train_samples], y_data[:train_samples]))
# 打乱顺序,数据预处理,设置批大小
train_db = train_db.shuffle(10000).map(train_preprocess).batch(settings.BATCH_SIZE)
# 开发集(验证集)
dev_db = tf.data.Dataset.from_tensor_slices(
(x_data[train_samples:train_samples + dev_samples], y_data[train_samples:train_samples + dev_samples]))
# 数据预处理,设置批大小
dev_db = dev_db.map(dev_preprocess).batch(settings.BATCH_SIZE)
# 测试集
test_db = tf.data.Dataset.from_tensor_slices(
(x_data[train_samples + dev_samples:], y_data[train_samples + dev_samples:]))
# 数据预处理,设置批大小
test_db = test_db.map(dev_preprocess).batch(settings.BATCH_SIZE)
3 モデルを構築する
データがすべて処理されたので、次はモデルについて考えます。まず第一に、私たちのデータセットは小さすぎるため、独自のネットワークを直接構築してトレーニングするのは明らかに良い解決策ではありません。これらのタイプのペットは実際には区別するのが非常に難しいため、これらのデータを適切に適合させるにはモデルにある程度の複雑さが必要ですが、データが小さすぎるため、最終結果は過適合になるはずであるため、学習による開始からの移行を検討します。 。
一般に、ディープ畳み込みニューラル ネットワークのトレーニングは、単純な特徴から複雑な特徴まで、データ セットから特徴を抽出する段階的なプロセスであると考えられています。トレーニングされたモデルが学習するのは画像の特徴の抽出方法であるため、理論的には 、 imagenet データセットでトレーニングされたモデルを他の画像の特徴を抽出するために直接使用することもでき、これが転移学習の基礎でもあります。当然のことながら、この効果は新しいデータでの再トレーニングほど良くないことがよくありますが、トレーニング時間を大幅に節約でき、特定の状況では非常に役立ちます。そして、この特定のケースには、私たちが直面しているケースも含まれています。実際の問題に対するデータセットが小さすぎるのです。
転移学習といえばVGGシリーズが真っ先に思い浮かぶので、VGG19で一度動かしてみました。imagenet データセットで事前トレーニングされた VGG19 ネットワークを使用し 、最上位の完全に接続されたレイヤーを削除し、後続のトレーニング中に変更されないようにすべてのパラメーターをフリーズします。次に、独自の全結合層を追加すると、最終的な出力層ノードは 4 になり、4 つの分類問題に対応します。トレーニングを開始します。
トレーニング セットでのモデルのエラー パフォーマンスはかなり良好ですが、検証セットでの精度は基本的に 70% 以上です。明らかに、このモデルは過剰適合されています。
そこで、パラメータが 7M しかない DenseNet121 に注目しました。案の定、一定期間のチューニングの後、モデルのパフォーマンスは大幅に向上し、トレーニング セットでは約 91%、検証セットでは約 93% に達しました。DenseNet121 の場合、この問題は過学習ではなく、過小学習です。つまり、データセットのサイズが小さすぎます。
# -*- coding: utf-8 -*-
# @File: models.py
# @Author: 嘟粥yyds
# @Time: 2023/08/25
import tensorflow as tf
import settings
from tensorflow.keras.utils import plot_model
def my_densenet():
"""
创建并返回一个基于densenet的Model对象
"""
# 获取densenet网络,使用在imagenet上训练的参数值,移除头部的全连接网络,池化层使用max_pooling
densenet = tf.keras.applications.DenseNet121(include_top=False, weights='imagenet', pooling='max')
# 冻结预训练的参数,在之后的模型训练中不会改变它们
densenet.trainable = False
# 构建模型
model = tf.keras.Sequential([
# 输入层,shape为(None,224,224,3)
tf.keras.layers.Input((224, 224, 3)),
# 输入到DenseNet121中
densenet,
# 将DenseNet121的输出展平,以作为全连接层的输入
tf.keras.layers.Flatten(),
# 添加BN层
tf.keras.layers.BatchNormalization(),
# 随机失活
tf.keras.layers.Dropout(0.5),
# 第一个全连接层,激活函数relu
tf.keras.layers.Dense(512, activation=tf.nn.relu),
# BN层
tf.keras.layers.BatchNormalization(),
# 随机失活
tf.keras.layers.Dropout(0.5),
# 第二个全连接层,激活函数relu
tf.keras.layers.Dense(64, activation=tf.nn.relu),
# BN层
tf.keras.layers.BatchNormalization(),
# 输出层,为了保证输出结果的稳定,这里就不添加Dropout层了
tf.keras.layers.Dense(settings.CLASS_NUM, activation=tf.nn.softmax)
])
return model
if __name__ == '__main__':
model = my_densenet()
model.summary()
plot_model(model, show_shapes=True, to_file='model.png', dpi=200)
モデルの概要:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
densenet121 (Functional) (None, 1024) 7037504
flatten (Flatten) (None, 1024) 0
batch_normalization (BatchN (None, 1024) 4096
ormalization)
dropout (Dropout) (None, 1024) 0
dense (Dense) (None, 512) 524800
batch_normalization_1 (Batc (None, 512) 2048
hNormalization)
dropout_1 (Dropout) (None, 512) 0
dense_1 (Dense) (None, 64) 32832
batch_normalization_2 (Batc (None, 64) 256
hNormalization)
dense_2 (Dense) (None, 4) 260
=================================================================
Total params: 7,601,796
Trainable params: 561,092
Non-trainable params: 7,040,704
_________________________________________________________________
パラメータの総数は 7,601,796 で、そのうち 561,092 がトレーニング可能なパラメータです。
4 モデルのトレーニングと検証
モデルとデータの準備ができたので、トレーニングを開始できます。トレーニング スクリプトを書いてみましょう。
# -*- coding: utf-8 -*-
# @File: train.py
# @Author: 嘟粥yyds
# @Time: 2023/08/25
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard
from data import train_db, dev_db
import models
import settings
# 从models文件中导入模型
model = models.my_densenet()
# 创建 TensorBoard 回调对象
tensorboard_callback = TensorBoard(log_dir='logs', histogram_freq=1, write_graph=True, write_images=True)
# 配置优化器、损失函数、以及监控指标
model.compile(tf.keras.optimizers.Adam(settings.LEARNING_RATE), loss=tf.keras.losses.categorical_crossentropy,
metrics=['accuracy'])
# 在每个epoch结束后尝试保存模型参数,只有当前参数的val_accuracy比之前保存的更优时,才会覆盖掉之前保存的参数
model_check_point = ModelCheckpoint(filepath=settings.MODEL_PATH, monitor='val_accuracy',
save_best_only=True)
# 创建早停回调对象
early_stopping = EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True)
# 创建学习率减少回调对象
lr_decay = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, min_lr=1e-6)
# 使用高级接口进行训练
model.fit(train_db, epochs=settings.TRAIN_EPOCHS, validation_data=dev_db,
callbacks=[model_check_point, early_stopping, lr_decay, tensorboard_callback])
これで、トレーニング用のスクリプトを実行できるようになり、最適なパラメーターが に保存されますsettings.MODEL_PATH
。トレーニングが完了したら、次の検証スクリプトを呼び出して、検証セットとテスト セットでのモデルのパフォーマンスを検証する必要があります。
# -*- coding: utf-8 -*-
# @File : eval.py
# @Author : 嘟粥yyds
# @Time : 2023/08/25
import tensorflow as tf
from data import dev_db, test_db
from models import my_densenet
import settings
# 创建模型
model = my_densenet()
# 加载参数
model.load_weights(settings.MODEL_PATH)
# 因为想用tf.keras的高级接口做验证,所以还是需要编译模型
model.compile(tf.keras.optimizers.Adam(settings.LEARNING_RATE), loss=tf.keras.losses.categorical_crossentropy,
metrics=['accuracy'])
# 验证集accuracy
print('dev', model.evaluate(dev_db))
# 测试集accuracy
print('test', model.evaluate(test_db))
# 查看识别错误的数据
for x, y in test_db:
y_pred = model(x)
y_pred = tf.argmax(y_pred, axis=1).numpy()
y_true = tf.argmax(y, axis=1).numpy()
batch_size = y_pred.shape[0]
for i in range(batch_size):
if y_pred[i] != y_true[i]:
print('{} 被错误识别成 {}!'.format(settings.CODE_CLASS_MAP[y_true[i]], settings.CODE_CLASS_MAP[y_pred[i]]))
16/16 [==============================] - 9s 99ms/step - loss: 0.1439 - accuracy: 0.9713
dev [0.1438767910003662, 0.9713114500045776]
16/16 [==============================] - 1s 85ms/step - loss: 0.1606 - accuracy: 0.9549
test [0.16057191789150238, 0.9549180269241333]
猫 被错误识别成 兔!
猫 被错误识别成 鼠!
猫 被错误识别成 狗!
鼠 被错误识别成 兔!
猫 被错误识别成 狗!
兔 被错误识别成 猫!
兔 被错误识别成 猫!
猫 被错误识别成 兔!
狗 被错误识别成 鼠!
猫 被错误识别成 兔!
狗 被错误识别成 兔!
検証セットのモデルの精度は 97.13%、テスト セットの精度は 95.49% で、期待どおりの結果となっていることがわかります。結局のところ、使用されたデータは非常に小さいのです。
5 モデルの展開
このプロジェクトのモデル展開でも、展開に Gradio が使用されており、その利点は自明のこと、つまり利便性です。
import gradio as gr
import tensorflow as tf
import settings
from models import my_densenet
import matplotlib as mpl
mpl.use('TkAgg')
# 导入模型
model = my_densenet()
# 加载训练好的参数
model.load_weights(settings.MODEL_PATH)
def classify_pet_image(input_image):
"""
宠物图片分类接口,上传一张图片,返回此图片上的宠物是哪种类别,概率多少
"""
# 进行数据预处理
# x = tf.image.decode_image(input_image, channels=3)
x = tf.convert_to_tensor(input_image)
x = tf.image.resize(x, (224, 224))
x = x / 255.
x = (x - tf.constant(settings.IMAGE_MEAN)) / tf.constant(settings.IMAGE_STD)
x = tf.reshape(x, (1, 224, 224, 3))
# 预测
y_pred = model(x)
pet_cls_code = tf.argmax(y_pred, axis=1).numpy()[0]
pet_cls_prob = float(y_pred.numpy()[0][pet_cls_code])
pet_cls_prob = '{}%'.format(int(pet_cls_prob * 100))
pet_class = settings.CODE_CLASS_MAP.get(pet_cls_code)
# 格式化输出为纯文本
output_text = "宠物类别:{} \n概率:{}".format(pet_class, pet_cls_prob)
return output_text
gr.close_all()
demo = gr.Interface(fn=classify_pet_image,
inputs=[gr.Image(label="Upload image")],
outputs=[gr.Textbox(label="识别结果")],
title="宠物识别Demo",
description="Classify your pet!",
allow_flagging="never"
)
demo.launch(share=True, debug=True, server_port=10055)
6 プロジェクトアドレス
Github にアクセスできない場合は、ブロガーのホームページのリソースからダウンロードすることもできます。