使用slim训练花分类
下载安装slim
通常clone一些google的models
即可。下面还是分类花,处理的时候步完全按照slim教程来,查看官方教程在这里
安装完成后直接使用自带脚本就能下载数据集同时转换tfrecord文件,为了加深理解,这里自己下载数据集同时手动转换tfrecord文件。
生成tfrecord数据集
如果你的服务器能联网直接按照官网教程输入
脚本如下:
$ DATA_DIR=/tmp/data/flowers
$ python download_and_convert_data.py \
--dataset_name=flowers \
--dataset_dir="${DATA_DIR}"
如果没网需要自己下载flower_photos.tgz文件将.tgz文件放在DATA_DIR
下面。
使用inceptionv1训练flower
下载wget http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz
,文件夹里面是一个ckeckpoint文件
训练脚本如下:
# 开启2xGPU
export CUDA_VISIBLE_DEVICES=0,1
# tfrecord文件所在
DATASET_DIR=/ssd/flower_tfrecord
# 训练文集存放地
TRAIN_DIR=/ssd/flower_slim_train
# inceptionV1 checkpoint文件位置
CHECKPOINT_PATH=/home/mc12/models/research/slim/inception_v1.ckpt
python train_image_classifier.py \
--train_dir=${TRAIN_DIR} \
--dataset_dir=${DATASET_DIR} \
--dataset_name=flowers \
--dataset_split_name=train \
--model_name=inception_v1 \
--checkpoint_path=${CHECKPOINT_PATH} \
--checkpoint_exclude_scopes=InceptionV1/Logits,InceptionV1/AuxLogits \
--trainable_scopes=InceptionV1/Logits,InceptionV1/AuxLogits
训练结果如下:
评估训练结果(在新的机器上训练Inception V3,这里参数为新的机器上的,v1的该v3为对应的数据即可):
python eval_image_classifier.py \
--dataset_name=flowers \
# tfrecord数据所在目录
--dataset_dir=/home/liushuai/tensorflow_hub_demo \
--dataset_split_name=train \
# 模型,如果是v1选择这里的v1
--model_name=inception_v3 \
--checkpoint_path=/tmp/tfmodel \
# 评估数据所在目录
--eval_dir=/home/liushuai/tensorflow_hub_demo \
--batch_size=32
输出结果:
INFO:tensorflow:Evaluation [10/104]
INFO:tensorflow:Evaluation [20/104]
INFO:tensorflow:Evaluation [30/104]
INFO:tensorflow:Evaluation [40/104]
INFO:tensorflow:Evaluation [50/104]
INFO:tensorflow:Evaluation [60/104]
INFO:tensorflow:Evaluation [70/104]
INFO:tensorflow:Evaluation [80/104]
INFO:tensorflow:Evaluation [90/104]
INFO:tensorflow:Evaluation [100/104]
INFO:tensorflow:Evaluation [104/104]
eval/Accuracy[0.95703125]
eval/Recall_5[1]