一、前言
上一篇博客《有趣的卷积神经网络》介绍如何基于deeplearning4j对手写数字识别进行训练,对于整个训练集只训练了一次,正确率是0.9897,随着迭代次数的增加,网络模型将更加逼近训练集,下面是对训练集迭代十次的评估结果,总之迭代次数的增加会更加逼近模型(注:增加迭代次数有时也会发生过拟合,有时候也并非很奏效,具体情况具体分析)。
Accuracy: 0.9919
Precision: 0.9919
Recall: 0.9918
F1 Score: 0.9918
二、导读
1、web环境搭建
2、基于canvas构建前端画图界面
3、整合dl4j训练模型
三、web环境搭建
1、eclipse new一个Maven project ,填好maven坐标,packaging选war
<groupId>org.dl4j</groupId>
<artifactId>digitalrecognition</artifactId>
<version>0.0.1-SNAPSHOT</version>
<packaging>war</packaging>
2、配置Jar包依赖,由于servlet-api一般由web容器提供,所以scope为provided,这样不会被打入war包里。
<dependencies>
<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-webmvc</artifactId>
<version>4.3.4.RELEASE</version>
</dependency>
<dependency>
<groupId>javax.servlet</groupId>
<artifactId>servlet-api</artifactId>
<version>2.5</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<version>2.5.3</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-annotations</artifactId>
<version>2.5.3</version>
</dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<version>2.5.3</version>
</dependency>
<dependency>
<groupId>commons-fileupload</groupId>
<artifactId>commons-fileupload</artifactId>
<version>1.3.1</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>0.9.1</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>0.9.1</version>
</dependency>
</dependencies>
3、为了开发方便,不用把web工程部署到外置web容器,所以在开发时用mavan tomcat插件是比较方便的。运行时mvn tomcat7:run即可
<build>
<plugins>
<plugin>
<groupId>org.apache.tomcat.maven</groupId>
<artifactId>tomcat7-maven-plugin</artifactId>
<version>2.2</version>
<configuration>
<uriEncoding>UTF-8</uriEncoding>
<path>/</path>
<port>8080</port>
<protocol>org.apache.coyote.http11.Http11NioProtocol</protocol>
<maxThreads>1000</maxThreads>
<minSpareThreads>100</minSpareThreads>
</configuration>
</plugin>
</plugins>
</build>
4、web常规配置web.xml,filter、servlet、listener这里就略去了。
四、前端canvas画图实现
1、html元素、css
<style type="text/css">
body {
padding: 0;
margin: 0;
background: white;
}
#canvas {
margin: 100px 0 0 300px;
}
#canvas>span {
color: white;
font-size: 14px;
}
#result {
margin: 0px 0 0 300px;
}
</style>
<html>
<head>
<title>数字识别</title>
</head>
<body>
<canvas id="canvas" width="280" height="280"></canvas>
<button onclick="predict()">预测</button>
<div id="result">
识别结果:<font size="18" id="digit"></font>
</div>
</body>
</html>
2、js代码实现在canvas画布连线操作,并将图片转化为base64格式,ajax发送给后端,这里画布的大小是280px,所以图片到了后端,需要缩小至十分之一。
<script src="/js/jquery-3.2.1.min.js"></script>
<script type="text/javascript">
/*获取绘制环境*/
var canvas = $('#canvas')[0].getContext('2d');
canvas.strokeStyle = "white";//线条的颜色
canvas.lineWidth = 10;//线条粗细
canvas.fillStyle = 'black'
canvas.fillRect(0, 0, 280, 280);
$('#canvas').on('mousedown', function() {
/*开始绘制*/
canvas.beginPath();
/*设置动画绘制起点坐标*/
canvas.moveTo(event.pageX - 300, event.pageY - 100);
$('#canvas').on('mousemove', function() {
/*设置下一个点坐标*/
canvas.lineTo(event.pageX - 300, event.pageY - 100);
/*画线*/
canvas.stroke();
});
}).on('mouseup', function() {
$('#canvas').off('mousemove');
});
function predict() {
var img = $('#canvas')[0].toDataURL("image/png");
$.ajax({
url : "/digitalRecognition/predict",
type : "post",
data : {
"img" : img.substring(img.indexOf(",") + 1)
},
success : function(response) {
$("#digit").html(response);
},
error : function() {
}
});
}
</script>
整体呈现的界面如下,可以画图。
五、后端java代码
@RequestMapping("/digitalRecognition")
@Controller
public class DigitalRecognitionController implements InitializingBean {
private MultiLayerNetwork net;
@ResponseBody
@RequestMapping("/predict")
public int predict(@RequestParam(value = "img") String img) throws Exception {
String imagePath= generateImage(img);//将base64图片转化为png图片
imagePath= zoomImage(imagePath);//将图片缩小至28*28
DataNormalization scaler = new ImagePreProcessingScaler(0, 1);
ImageRecordReader testRR = new ImageRecordReader(28, 28, 1);
File testData = new File(imagePath);
FileSplit testSplit = new FileSplit(testData, NativeImageLoader.ALLOWED_FORMATS);
testRR.initialize(testSplit);
DataSetIterator testIter = new RecordReaderDataSetIterator(testRR, 1);
testIter.setPreProcessor(scaler);
INDArray array = testIter.next().getFeatureMatrix();
return net.predict(array)[0];
}
private String generateImage(String img) {
BASE64Decoder decoder = new BASE64Decoder();
String filePath = WebConstant.WEB_ROOT + "upload/"+UUID.randomUUID().toString()+".png";
try {
byte[] b = decoder.decodeBuffer(img);
for (int i = 0; i < b.length; ++i) {
if (b[i] < 0) {
b[i] += 256;
}
}
OutputStream out = new FileOutputStream(filePath);
out.write(b);
out.flush();
out.close();
} catch (Exception e) {
e.printStackTrace();
}
return filePath;
}
private String zoomImage(String filePath){
String imagePath=WebConstant.WEB_ROOT + "upload/"+UUID.randomUUID().toString()+".png";
try {
BufferedImage bufferedImage = ImageIO.read(new File(filePath));
Image image = bufferedImage.getScaledInstance(28, 28, Image.SCALE_SMOOTH);
BufferedImage tag = new BufferedImage(28, 28, BufferedImage.TYPE_INT_RGB);
Graphics g = tag.getGraphics();
g.drawImage(image, 0, 0, null); // 绘制处理后的图
g.dispose();
ImageIO.write(tag, "png",new File(imagePath));
} catch (Exception e) {
e.printStackTrace();
}
return imagePath;
}
@Override
public void afterPropertiesSet() throws Exception {
net = ModelSerializer.restoreMultiLayerNetwork(new File(WebConstant.WEB_ROOT + "model/minist-model.zip"));
}
}
代码说明:
1、InitializingBean是spring bean生命周期中的一个环节,spring构建bean的过程中会执行afterPropertiesSet方法,这里用这个方法来加载已经定型的网络。
2、generateImage是用来将前端传过来的base64串转化为png格式。
3、zoomImage方法将前端的280*280缩小至28*28和训练数据一致,并存到webroot的upload目录下。
4、predict进行预测,将转化好的28*28的图片读取出来,张量化,把像素点的值压缩至0到1,预测,最后结果是一个数组,由于只有一张图片,取数组的第一个元素即可。
六、测试,mvn tomcat7:run,浏览器访问http://localhost:8080即可玩手写数字识别了
测试结果马马虎虎,大体上实现了基本功能。
git地址:https://gitee.com/lxkm/dl4j-demo/tree/master/digitalrecognition
快乐源于分享。