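This walkthrough is based mainly on this project, with some modifications, and uses the flowers dataset as the example:
The complete code is posted on GitHub.
First prepare the dataset under pic/train and pic/validation. Then run: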
python data_convert.py -t pic/ \
--train-shards 2 \
--validation-shards 2 \
--num-threads 2 \
--dataset-name flowers
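Before running the conversion it can be worth checking the folder layout. The sketch below assumes that pic/train and pic/validation each contain one subfolder per class (for example pic/train/roses); that layout is an assumption, not something stated above. The helper names and the extension list are hypothetical.

```python
import os

# Assumed set of image extensions; adjust to match your data.
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}


def count_images(split_dir):
    """Return {class_name: image_count} for one split directory."""
    counts = {}
    for name in sorted(os.listdir(split_dir)):
        class_dir = os.path.join(split_dir, name)
        if not os.path.isdir(class_dir):
            continue  # skip stray files such as label.txt
        counts[name] = sum(
            1 for f in os.listdir(class_dir)
            if os.path.splitext(f)[1].lower() in IMAGE_EXTS
        )
    return counts


def check_dataset(root):
    """Count images per class and list classes found in only one split."""
    train = count_images(os.path.join(root, "train"))
    val = count_images(os.path.join(root, "validation"))
    only_in_one_split = sorted(set(train) ^ set(val))
    return train, val, only_in_one_split


if __name__ == "__main__":
    train, val, mismatched = check_dataset("pic")
    print("train:", train)
    print("validation:", val)
    if mismatched:
        print("classes present in only one split:", mismatched)
```

A class that shows up in only one split, or with zero images, is usually easier to fix here than after the .record files have been written.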
Copy the .record files and label.txt into slim/data, then download the pretrained MobileNet checkpoint and put it in the model directory. Then run:
python train_image_classifier.py \
--train_dir=flowers/train_log \
--dataset_name=flowers \
--train_image_size=299 \
--dataset_split_name=train \
--dataset_dir=data \
--model_name="mobilenet_v2_140" \
--checkpoint_path=model/mobilenet_v2_1.4_224.ckpt \
--checkpoint_exclude_scopes=MobilenetV2/Logits,MobilenetV2/AuxLogits \
--trainable_scopes=MobilenetV2/Logits,MobilenetV2/AuxLogits \
--max_number_of_steps=20000 \
--batch_size=16 \
--learning_rate=0.001 \
--learning_rate_decay_type=fixed \
--log_every_n_steps=10 \
--optimizer=rmsprop \
--weight_decay=0.00004 \
--label_smoothing=0.1 \
--num_clones=1 \
--num_epochs_per_decay=2.5 \
--moving_average_decay=0.9999 \
--learning_rate_decay_factor=0.98 \
--preprocessing_name="inception_v2"
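One thing to note about the flags above: with --learning_rate_decay_type=fixed the learning rate should stay constant, so --num_epochs_per_decay and --learning_rate_decay_factor would only come into play with the exponential type. The sketch below illustrates the difference in plain Python. It follows my understanding of how the slim training script computes a staircase exponential decay and is not a copy of that script; num_samples=3000 is a made-up training set size.

```python
def learning_rate_at(step, base_lr=0.001, decay_type="fixed",
                     num_samples=3000, batch_size=16,
                     num_epochs_per_decay=2.5, decay_factor=0.98):
    """Approximate the learning rate at a given global step.

    num_samples is a hypothetical training set size; use your own.
    """
    if decay_type == "fixed":
        # Constant rate: the decay flags are not used.
        return base_lr
    if decay_type == "exponential":
        # Number of steps between two decays (staircase behaviour).
        decay_steps = int(num_samples * num_epochs_per_decay / batch_size)
        return base_lr * decay_factor ** (step // decay_steps)
    raise ValueError("unsupported decay_type: %r" % decay_type)


if __name__ == "__main__":
    for step in (0, 5000, 20000):
        print(step,
              learning_rate_at(step, decay_type="fixed"),
              learning_rate_at(step, decay_type="exponential"))
```

If the decay flags are meant to take effect, switching the decay type to exponential is the change to try; otherwise they can be left out without changing the run.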
Evaluate the model:
python eval_image_classifier.py \
--checkpoint_path=flowers/train_log \
--eval_dir=flowers/eval_log \
--dataset_name=flowers \
--dataset_split_name=validation \
--dataset_dir=data \
--model_name="mobilenet_v2_140" \
--batch_size=32 \
--num_preprocessing_threads=2 \
--eval_image_size=299
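Export the inference graph, then freeze it with the trained checkpoint. Note that the export command uses paths relative to slim/, while the freeze_graph.py command uses paths that start with slim/: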
python export_inference_graph.py \
--alsologtostderr \
--model_name="mobilenet_v2_140" \
--image_size=299 \
--output_file=flowers/export/mobilenet_v2_140_inf_graph.pb \
--dataset_name flowers
python freeze_graph.py \
--input_graph slim/flowers/export/mobilenet_v2_140_inf_graph.pb \
--input_checkpoint slim/flowers/train_log/model.ckpt-20000 \
--input_binary true \
--output_node_names MobilenetV2/Predictions/Reshape_1 \
--output_graph slim/flowers/export/frozen_graph.pb
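The freeze command above hardcodes model.ckpt-20000, which only exists if training ran all the way to step 20000. If training stopped earlier, the prefix to pass to --input_checkpoint is different. Here is a minimal sketch for finding the newest one, assuming the usual TensorFlow naming where each checkpoint has a model.ckpt-<step>.index file; the helper name is hypothetical.

```python
import os
import re


def latest_checkpoint_prefix(train_dir):
    """Return the checkpoint prefix with the highest step, or None."""
    pattern = re.compile(r"^model\.ckpt-(\d+)\.index$")
    steps = []
    for name in os.listdir(train_dir):
        match = pattern.match(name)
        if match:
            steps.append(int(match.group(1)))
    if not steps:
        return None
    # Compare steps as integers; sorting the names as strings would
    # put 9000 after 20000.
    return os.path.join(train_dir, "model.ckpt-%d" % max(steps))


if __name__ == "__main__":
    print(latest_checkpoint_prefix("slim/flowers/train_log"))
```

The returned value is a prefix, not a file name, which matches what --input_checkpoint takes in the command above.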
Test on one or more images:
python classify_image_test.py \
--model_path slim/flowers/export/frozen_graph.pb \
--label_path data_prepare/pic/label.txt \
--image_file test_image.jpg
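The output node MobilenetV2/Predictions/Reshape_1 gives one probability per class, so the last step of any test script is mapping those probabilities back to names from label.txt. The sketch below shows that step only and is not the contents of classify_image_test.py. It assumes label.txt holds one class name per line, in the same order as the class indices; both helper names are hypothetical.

```python
def load_labels(path):
    """Read class names, assuming one name per line in index order."""
    with open(path, encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def top_k(probabilities, labels, k=3):
    """Return the k most likely (label, probability) pairs, best first."""
    ranked = sorted(enumerate(probabilities),
                    key=lambda pair: pair[1], reverse=True)
    return [(labels[index], prob) for index, prob in ranked[:k]]


if __name__ == "__main__":
    labels = load_labels("data_prepare/pic/label.txt")
    # Placeholder probabilities; in practice these come from the model.
    fake_probs = [1.0 / len(labels)] * len(labels)
    for name, prob in top_k(fake_probs, labels):
        print("%s: %.4f" % (name, prob))
```

If the number of lines in label.txt does not match the length of the probability vector, the label file and the exported graph were probably built from different datasets.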