Mnist手写体的分类预测

下面是python代码:

# coding=UTF-8
"""
使用训练好的caffe模型预测手写体程序
"""

import os
import sys
import numpy as np
import matplotlib.pyplot as plt
caffe_root="/home/pcb/caffe/"    #设置caffee的目录
sys.path.insert(0,caffe_root+"python")
import caffe
#指定LetNet的网络定义模型
Model_file="/home/pcb/caffe/examples/mnist/lenet.prototxt"
#加载训练好的Model模型
Pretrained="/home/pcb/caffe/examples/mnist/lenet_iter_10000.caffemodel"
#测试图片路径
Image_file="/home/pcb/caffe/data/mnist/train_13.bmp"

#caffe接口载入文件
input_image=caffe.io.load_image(Image_file,color=False)

#载入LeNet分类器
net=caffe.Classifier(Model_file,Pretrained)

prediction=net.predict([input_image],oversample=False)
caffe.set_mode_cpu() #设置为CPU模式
print "predicted calss:",prediction[0].argmax()

最终预测结果为:
这里写图片描述
选择了一个6的手写体图像,然后最后分类结果也是6,这样就正确分类了!

猜你喜欢

转载自blog.csdn.net/pcb931126/article/details/81412059
今日推荐