Android Studio4.0+OpenCv4.3.0实战之SVM算法

本文主要介绍如何使用Android Studio4.0+OpenCv4.3.0中的SVM算法手写体数字识别。(阅读本文之前最好对SVM算法基础原理有一点了解)

一、简单介绍一下SVM

SVM全称Support Vector Machine,中文:支持向量机,这是一种监督学习算法。具体原理我感觉还是很复杂的,这里我就不具体说了。我就尝试一下用大白话简单描述一下,希望有助于理解。
在两类样本之间按照间隔最大化绘制一条直线(或一个平面),后面的样本会依据这一条直线(或一个平面)判断自己属于哪一类。
这里我们就大致清楚SVM算法的使用流程了。首先我们需要有已经分好类的样本,将样本按照分类属性进行训练模型,训练好的模型则可以对未知分类的样本进行分类。
由于SVM具有参数少,不容易过拟合等优点,在小样本情况下效果可能优于人工神经网络。还是很值的学习的。

二、常用API介绍

创建SVM。使用SVM首先要创建一下SVM,创建完需要给SVM配置一些参数,这个会在第三部分体现。

    public static SVM create() {
        return SVM.__fromPtr__(create_0());
    }

训练API,SVM有很多训练相关的API,考虑初学者,我们这里选取trainAuto()参数最少的API进行介绍。

    public boolean trainAuto(Mat samples, int layout, Mat responses) {
        return trainAuto_8(nativeObj, samples.nativeObj, layout, responses.nativeObj);
    }

samples放置样本集,你可以是一行作为一个样本或者一列作为一个样本,当然后面参数需要同步
layout样本类型ROW_SAMPLE或者COL_SAMPLE,如果你的样本集是一行一个样本就使用ROW_SAMPLE,如果你的样本集是一列一个样本就使用COL_SAMPLE。
responses标签集,需要注意的是标签集需要与样本集对应,确保每一个样本都有一个标签。
补充samples需要转换为CV_32F,responses需要转换为CV_32S
输出测试结果

public float predict(Mat samples) {
    return predict_2(nativeObj, samples.nativeObj);
}

samples为输入测试样本,其格式需要和样本集中样本保持一致.
return返回测试结果,也就是测试出样本对应标签

三、牛刀小试

首先找到opencv提供的digits.png作为样本集以及测试集,这是一张20001000的图片每一个小数字的像素为2020。当然这张图片是没有办法直接使用的我们首先需要对这张图片进行切割,切割成一张一张的小数字图片。
在这里插入图片描述
首先我们把这张图片导入Mat中(比较简单,默认大家都导入了),需要注意的是原图像素为2000*1000你导入Mat中图片的大小会发生变化(部分API会改变图片尺寸),一旦图片大小发生变化那么每一个数字的尺寸也会发生变化,因此切割图片之前需要先检测图片尺寸。我们这里把图片导入到rgbMat中,使用如下代码。

        String digitsSize = String.format("digitsSize %d  %d ", rgbMat.cols(),rgbMat.rows());
        Log.d("digitsSize",digitsSize);

我们可以在Logcat中看到实际导入图像的尺寸信息,我这里的是70003500,也就是说我的每一个数字的尺寸是7070。知道这个参数下面我们就可以做切割了。
在这里插入图片描述
切割代码,我这里是把图像切割完按照不同数字分别放在不同文件夹中并且用数字命名文件夹,图片命名从0开始往后排。

        //颜色空间转换,把原图像灰度化
        Imgproc.cvtColor(rgbMat, grayMat, Imgproc.COLOR_RGB2GRAY);
        //获取原图像尺寸
        int nWidth = grayMat.cols();
        int nHeight  = grayMat.rows();
        //切割后单位模块大小
        int b = 70;
        //计算横向个数
        int m = nWidth / b;
        //计算列向个数
        int n = nHeight / b;
        //文件夹名,文件个数(文件名)
        int  filename = 0,filenum=0;
        /**
         * 图片分割
         */
        for(int i = 0; i < m ;i++){
            //行开始位置
            int startsetRow = i*b;
            //由于每一个数字有五行,所以五行处理完,就进入下一个文件夹
            if(i%5==0&&i!=0)
            {
                filename++;
                filenum=0;
            }
            for (int j = 0; j < n; j++)
            {
                int startsetCol = j*b; //列上的偏移量
                //截取70*70的小块
                Mat tmp = new Mat(b,b,CvType.CV_8UC1);
                for(int x = 0;x < b;x++)
                    for(int y = 0;y < b;y++){
                        tmp.put(x,y,grayMat.get(startsetRow+x,startsetCol+y)[0]);
                    }
                String file = String.format("%d", filename);
                File fileDir = new File(Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_PICTURES),file);
                    if(!fileDir.exists()){
                        fileDir.mkdirs();
                    }
                String name = String.format("%d.jpg", filenum++);
                File tempFile = new File(fileDir.getAbsoluteFile()+File.separator,name);
                Imgcodecs.imwrite(tempFile.getAbsolutePath(),tmp);
            }
        }

最后我们看一下结果。我们可以看到没有问题,所有数字都完美切割出来,并且放在对应的文件夹下面。
在这里插入图片描述
有了基础图片,我们就可以创造训练集和标签集了并且训练啦。

		Bitmap  trainBitmap;
       //用于加载图片
        Mat viewMat1 = new Mat();
        //创建标签矩阵
        Mat logo = new Mat(1,1,CV_8UC1);
        //样本集
        Mat trainingImages = new Mat();
        //标签集
        Mat trainingLabels = new Mat();
        /**
         * 生成样本集和标签集(需要保证样本和标签对应)
         */
        for(int i = 0; i < 10 ;i++){
            //读取文件夹路径
            String file = String.format("%d", i);
            File fileDir = new File(Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_PICTURES),file);
            //我们只是做简单实验,没有特别要求,所以我们每一个数字只选择10张图片放在样本集里
            for(int j = 0; j < 10 ; j++){
                //读取图片路径
                String name = String.format("%d.jpg", j);
                File tempFile = new File(fileDir.getAbsoluteFile()+File.separator,name);
                //读取图片Bitmap类型
                trainBitmap = BitmapFactory.decodeFile(tempFile.getAbsolutePath());
                //转换成Mat类型
                Utils.bitmapToMat(trainBitmap, viewMat1);
                /**
                 * 对应导入样本和标签
                 */
                //颜色空间转换,灰度化
                Imgproc.cvtColor(viewMat1, viewMat1, Imgproc.COLOR_RGB2GRAY);
                //转换成行矩阵
                viewMat1 = viewMat1.reshape(0,1);
                //把得到的行矩阵连接在trainingImages最后一行作为样本
                trainingImages.push_back(viewMat1);
                //标签赋值
                logo.put(0,0,i);
                //把这个图片样本对于的标签加载在trainingLabels最后一行
                trainingLabels.push_back(logo);
            }
        }
        /**
        * 训练
        */
        //创建svm
        SVM svm = SVM.create();
        //设置SVM相关参数
        svm.setType(SVM.C_SVC);
        svm.setKernel(SVM.LINEAR);
        svm.setTermCriteria(new TermCriteria((TermCriteria.EPS + TermCriteria.COUNT),100,1));
        svm.setC(0.1);
        //类型转换
        trainingImages.convertTo(trainingImages, CV_32FC1);
        trainingLabels.convertTo(trainingLabels, CV_32SC1);
        //训练模型
        svm.trainAuto(trainingImages,ROW_SAMPLE,trainingLabels);
        Log.i(TAG, " svm.trainAuto sucess...");

训练完成我们来找图片测试一下,我是一个懒人所以我直接用从digits.png切割出来的图片做测试,由于训练模型的时候用了每一个数字前十张图片,所以测试我随机选取每个数字文件夹下编号10以上图片来测试效果.代码如下
//选取数字8文件夹

        Bitmap  bmp;
        Mat text = new Mat();
        //选取数字8文件夹
        String file = String.format("%d", 8);
        File fileDir = new File(Environment.getExternalStoragePublicDirectory(Environment.DIRECTORY_PICTURES),file);
        //选择编号100的图片
        String name = String.format("%d.jpg", 100);
        File tempFile = new File(fileDir.getAbsoluteFile()+File.separator,name);
        //导入图片
        bmp = BitmapFactory.decodeFile(tempFile.getAbsolutePath());
        //转换为Mat
        Utils.bitmapToMat(bmp, text);
        //灰度化
        Imgproc.cvtColor(text, text, Imgproc.COLOR_RGB2GRAY);
        //转成样本格式
        text = text.reshape(0,1);
        text.convertTo(text, CV_32FC1);
        //测试结果
        float response = svm.predict(text);
        //打印出样本应该属于的标签
        String result = String.format("result %f ", response);
        Log.d("text",result);

我们来看一下结果,结果和期望的一样,成功识别出测试图像为8.
在这里插入图片描述

到这样SVM实现手写数字识别算是完成了,感谢您可以看到这里。
希望本文可以对您的学习有帮助,您有任何疑问或者新的想法欢迎评论留言。
本人能力有限,如文中存在错误或者不足还请指出,非常感谢。

猜你喜欢

转载自blog.csdn.net/qq_41814560/article/details/107592545