基于BP神经网络的数字识别基础系统(四)

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/z_x_1996/article/details/68633264

基于BP神经网络的数字识别基础系统(四)

接上篇

上一篇的链接:http://blog.csdn.net/z_x_1996/article/details/68490009

3.系统设计

上一篇笔者已经讨论完了BP神经网络需要用到的知识点,接下来就开始设计符合我们标题的系统了。

首先我们要确定训练集以及测试集:下载链接:http://download.csdn.net/detail/z_x_1996/9799552


我们来分析训练集,首先训练的图片格式为bmp位图格式,位深度为8,分辨率为32*64,训练集分为0~9十个文件夹,每个文件夹里面有4张不同字体的相同数字(数字同文件夹名称),同时训练集里有一个target.txt文件,里面文件代表每一张图片的目标输出,一行就是一张图的目标输出,我们很容易看出输出有10个单元,每个数字对应一组输出。这里并没有采用二进制编码而是采用一对一编码,这样的好处在于可以很容易获得置信度,但是坏处也是显而易见的,那就是当样本类型很多时网络的输出会急剧增加。

我们再来看输入层,为了精简输入信息,我们将图片压缩,横竖均只取1/4的像素,均匀分布。这样输入单元有32*64/16=128个输入单元。

隐藏层有多种选择,首先确定隐藏层数,考虑到该数据组分类比较简单,故选择一层隐藏层,这层的单元数有多种选择,不同的选择会有不同的影响,这个影响我们后面再谈(如果忘了请记得提醒笔者),这里我们选择为4个。

至此我们便确定了网络结构,三层:

  • 输入层:128单元
  • 隐藏层:8单元
  • 输出层:10单元

这样我们也可以把权重向量的size确定了:

  • weightHK[][]:10x(8+1)
  • weightIH[][]:8x(128+1)

(这里+1的原因是要加上一个常数偏置项)

首先笔者先给出系统工程的结构图:

3.1 神经网络包

我们先构建神经网络元素包 com.zhangxiao.element。

首先自然来到我们SNeuron.java文件,该文件为一个神经元。

package com.zhangxiao.element;

public class SNeuron {
    private double[] weight;
    private double[] input;
    private int length;

    public SNeuron(double[] input,double[] weight){
        this.input = input;
        this.length = input.length;
        this.weight = weight;
    }

    //获得Sigmoid输出结果。
    public double getResult(){
        double sum = weight[0];
        for(int i=0;i<length;i++){
            sum += input[i]*weight[i+1];
        }
        return 1/(Math.exp(-sum)+1);
    }

}

没有什么好说的,然后是构建一层 Layer.java,该文件为一层的类。

package com.zhangxiao.element;

public class Layer {

    private SNeuron[] cells;
    private int number;
    private double[] input;
    private double[] output;
    private double[][] weight;
    //初始化神经层
    public Layer(int number,double[] input,double[][] weight){
        this.number = number;
        this.input = input;
        this.weight = weight;
        output = new double[number];
        cells = new SNeuron[number];
        for(int i=0;i<number;i++){
            cells[i] = new SNeuron(this.input,this.weight[i]);
        }
    }
    //获得神经层输出结果数组
    public void goForward() {
        for(int i=0;i<number;i++){
            output[i] = cells[i].getResult();
        }
    }

    public double[] getOutput() {
        return output;
    }
}

然后是构建一个神经系统(目前笔者写的代码只支持3层,即一个隐藏层),NervousSystem1H.java

package com.zhangxiao.element;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;

public class NervousSystem1H {
    private double[][] trainData;
    private double[] input;
    private double[] output;
    private double[] connection;
    public double[] getInputLayer() {
        return input;
    }

    private double[][] target;
    private Layer[] layers;
    private int[] structure;
    private double efficiency;
    private double[] deltaK;
    private double[] deltaH;

    private double[][] weightIH;
    private double[][] weightHK;        

    //初始化神经系统
    public NervousSystem1H(double efficiency,int[] structure,double[][] trainData,double[][] target) throws IOException{

        if(trainData[0].length!=structure[0]){
            System.out.println("训练数据长度与输入层长度不一致!");
            return;
        }

        this.trainData = trainData;
        this.target = target;
        this.efficiency = efficiency;
        this.structure = structure;

        //初始化数组
        this.input = new double[structure[0]];
        deltaK = new double[structure[2]];
        deltaH = new double[structure[1]];      
        for(int k=0;k<deltaK.length;k++){
            deltaK[k] = 0;
        }
        for(int h=0;h<deltaH.length;h++){
            deltaH[h] = 0;
        }
        weightIH = new double[structure[1]][structure[0]+1];
        weightHK = new double[structure[2]][structure[1]+1];
        for(int h=0;h<structure[1];h++){
            for(int i=0;i<structure[0]+1;i++){
                while(Math.abs((weightIH[h][i] = Math.random()/10-0.05))==0){}
            }
        }
        for(int k=0;k<structure[2];k++){
            for(int h=0;h<structure[1]+1;h++){
                while(Math.abs(weightHK[k][h] = Math.random()/10-0.05)==0){}
            }
        }

        //连接各层
        layers= new Layer[2];
        layers[0] = new Layer(structure[1],this.input,weightIH);
        connection = layers[0].getOutput();
        layers[1] = new Layer(structure[2],connection,weightHK);
        this.output = layers[1].getOutput();

    }

    //训练神经网络
    public void train() throws IOException{
        double error = 0;
        int process = 0;
        while((error = getError())>0.0001){
            System.out.println(process++ +":"+error);
            for(int d=0;d<trainData.length;d++){
                //正向传播输出
                goForward(trainData[d]);

                double[] outputK = layers[1].getOutput();
                double[] outputH = layers[0].getOutput();

                for(int k=0;k<deltaK.length;k++){
                    deltaK[k] = outputK[k]*(1-outputK[k])*(target[d][k]-outputK[k]);
                }
                for(int h=0;h<deltaH.length;h++){
                    deltaH[h] = 0;
                    for(int k=0;k<deltaK.length;k++){
                        deltaH[h] += outputH[h]*(1-outputH[h])*deltaK[k]*weightHK[k][h+1];
                    }
                }
                //更新权值

                for(int k=0;k<weightHK.length;k++){
                    weightHK[k][0] += efficiency*deltaK[k];
                    for(int h=1;h<weightHK[0].length;h++){
                        weightHK[k][h] += efficiency*deltaK[k]*outputH[h-1];
                    }
                }

                for(int h=0;h<weightIH.length;h++){
                    weightIH[h][0] += efficiency*deltaH[h];
                    for(int i=1;i<weightIH[0].length;i++){
                        weightIH[h][i] += efficiency*deltaH[h]*trainData[d][i-1];
                    }
                }
            }
        }
        System.out.println("最终误差为:"+getError());
    }

    //获取输出结果数组
    public void goForward(double[] input){
        setInput(input);
        for(int i = 0;i<structure.length-1;i++){
            layers[i].goForward();
        }
    }

    //获取误差
    public double getError(){
        double error = 0;
        for(int d=0;d<trainData.length;d++){
            goForward(trainData[d]);
            for(int i=0;i<target[0].length;i++){
                error += 0.5*(target[d][i]-output[i])*(target[d][i]-output[i]);
            }
        }
        return error/trainData.length/10;
    }

    //将训练好的权重保存到txt文件中方便查看以及二次调用
    public boolean saveWeight(File file) throws IOException{
        boolean flag = false;
        BufferedWriter bw = new BufferedWriter(new FileWriter(file));
        //写入weightIH
        for(int h=0;h<weightIH.length;h++){
            for(int i=0;i<weightIH[0].length;i++){
                bw.append(Double.toString(weightIH[h][i])+" ");
            }
            bw.append("\r\n");
            bw.flush();
        }
        //写入weightHK
        for(int k=0;k<weightHK.length;k++){
            for(int h=0;h<weightHK[0].length;h++){
                bw.append(Double.toString(weightHK[k][h])+" ");
            }
            bw.append("\r\n");
            bw.flush();
        }
        bw.close();
        return flag;
    }

    //调用训练好的网络
    public boolean loadWeight(File file) throws IOException{
        boolean flag = false;
        BufferedReader br = new BufferedReader(new FileReader(file));
        //写入weightIH
        String line;
        String[] strs;
        for(int h=0;h<weightIH.length;h++){
            line=br.readLine();
            strs = line.split(" ");
            for(int i=0;i<weightIH[0].length;i++){
                weightIH[h][i] = Double.parseDouble(strs[i]);
            }
        }
        //写入weightHK
        for(int k=0;k<weightHK.length;k++){
            line=br.readLine();
            strs = line.split(" ");
            for(int h=0;h<weightHK[0].length;h++){
                weightHK[k][h] = Double.parseDouble(strs[h]);
            }
        }
        br.close();
        return flag;
    }

    //网络每个输出单元的输出
    public double[] predict_all(double[] input){
        goForward(input);
        return output;
    }

    //输出预测数字
    public int preidict_result(double[] input){
        int result = -1;
        double max = -1;
        goForward(input);
        for(int i=0;i<output.length;i++){
            if(output[i]>max){
                max = output[i];
                result = 9-i;
            }
        }
        return result;
    }

    private void setInput(double[] input) {
        for(int i=0;i<this.input.length;i++){
            this.input[i] = input[i];
        }
    }

    public double[][] getWeightIH() {
        return weightIH;
    }

    public double[][] getWeightHK() {
        return weightHK;
    }

}

这里需要说明的是主要的计算量为goForward函数,这个是正向计算的函数。如果看懂了前面的原理这个文件其实也没什么好讲的,无非是把输出细节化,训练方法和前面所说一样。同样增加了getError函数来获取误差,因为笔者把Error来作为训练终止的要求。但是其实使用这个作为终止条件摒弃了增量梯度下降算法中不需要一次性加载所有数据的优点。计算Error必须使用所有的数据。

这样一个网络的架构就已经搭建好了,使用时我们只需要调用NervousSystem1H类中的方法就可以了。

3.2 主程序包

下面就是要创建针对项目的主程序包com.zhangxiao.window了。

Window.java中主要应该包括如下方法:

  • 获取训练数据
  • 图片数据转化为数组
  • 获取训练标签
  • 构建神经网络

    package com.zhangxiao.window;
    
    import java.awt.image.BufferedImage;
    import java.io.BufferedReader;
    import java.io.File;
    import java.io.FileReader;
    import java.io.IOException;
    import javax.imageio.ImageIO;
    
    import com.zhangxiao.element.NervousSystem1H;
    
    public class Window {
    
        public static void main(String[] args) throws IOException {
            String path = "这里填自己的路径";
            //获取训练素材
            double[] testData = new double[128];//这里记得获取测试数据!!!!!!!!!!!!!!!!!!
            double[][] target = getTarget(path, 10, 40);
            double[][] trainData = getTrainData(path, 40);
            int[] structure = new int[]{128,8,10};
            //构建神经网络
            NervousSystem1H s = new NervousSystem1H(0.01,structure,trainData,target);
            //训练神经网络
            System.out.println("训练中...");
            s.train();
            System.out.println("训练完毕!");    
            //保存weight数据
            s.saveWeight(new File("data/weight/weight.txt"));
    
            //载入保存的weight数据
            /*System.out.println("载入中...");
            s.loadWeight(new File("data/weight/weight.txt"));
            System.out.println("载入完成!");*/
    
            double[] result = s.predict_all(testData);
            for(int i=0;i<result.length;i++){
                System.out.print(result[i]+" ");
            }
        }
    
        //获取训练样本
        public static double[][] getTrainData(String direction,int number) throws IOException{
            double[][] trainData = new double[number][128];
            for(int d=0;d<number/4;d++){
                for(int i=0;i<4;i++){
                    trainData[4*d+i] = image2Array(direction+"/"+d+"/"+d+""+i+".bmp");
                }
            }
            return trainData;
        }
    
        //将图片转化为数组
        public static double[] image2Array(String str) throws IOException{
            double[] data = new double[16*8];
            BufferedImage image = ImageIO.read(new File(str));
            for(int i = 0;i<8;i++){
                for(int j = 0;j<16;j++){
                    int color = image.getRGB(4*i, 4*j);
                    int b = color&0xff;
                    int g = (color>>8)&0xff;
                    int r = (color>>8)&0xff;
                    data[8*j+i]=((int)(r*0.3+g*0.59+b*0.11))/255;
                }
            }
            return data;
        }
    
        //获取目标结果数组
        @SuppressWarnings("resource")
        public static double[][] getTarget(String str,int length,int number) throws IOException{
            BufferedReader br = new BufferedReader(new FileReader(str));
            double[][] data = new double[number][length];
            String line;
            String[] strs;
            int d = 0;
            while((line=br.readLine())!=null){
                strs = line.split(" ");
                for(int i=0;i<length;i++){
                    data[d][i] = Double.parseDouble(strs[i]);
                }
                d++;
            }
            if(d!=number){
                System.out.println("数据组数不匹配!");
                return null;
            }
            br.close();
            return data;
        }
    
    }
    

4.后记

到这里这个坑基本上算是填完了,当然笔者还是需要说明的是由于代码写的比较匆忙,很多冗余、不够优化以及结构问题比比皆是。希望大家能够谅解,如果有很好的建议方便留言。到目前为止这个系列笔者前前后后花费了很多的精力以及时间,终于完成了这个两万多字的系列,可以说从中也学到了很多东西,很多以前并不是很清楚的东西也理清楚了。另外这里给出大家一个优化的方向:

  • 加入冲量项,避开局部最小值。
  • 改变隐藏层的单元数。

如果觉得看完还是有些疑惑的建议自己再复建一下算法,或者你可以试试将隐藏层数变为2层,再来思考整个系统,相信你会受益匪浅!

猜你喜欢

转载自blog.csdn.net/z_x_1996/article/details/68633264
今日推荐