Deeplearning4j 实战 (12):Mnist替代数据集Fashion Mnist在CNN上的实验及结果

Mnist数据集的分类问题一直被认为是深度学习的Hello World。利用2层卷积网络,经过若干轮的训练后,在相应测试集上的准确率可以达到95%以上。经过调参后,甚至可以达到99%以上。其实,即使不用用卷积层提取特征,而是用传统的全连接网络也同样可以达到非常高的准确率。在Mnist数据集的官网上(http://yann.lecun.com/exdb/mnist/),除了基于神经网络的分类器,利用传统的分类方法,如:KNN,SVM,也都可以获得非常好的结果。下面就是部分模型分类效果的截图:


从以上结果分析可以发现,无论是浅层模型还是深度学习,在Mnist上的分类问题上都可以达到很高的精度,因此从某种角度也可以说,Mnist数据集复杂度不够,或者说Mnist分类问题并不是一个具有代表性的机器视觉问题。就这个问题,《Deep Learning》一书的作者Ian Goodfellow和著名开源项目Keras的作者Francois Chollet都有自己的评述,详情可转到下面两个链接:

1.Ian Goodfellow Commnet On Mnist DataSet

2.Francois Chollet's Comment

虽然Mnist可能并不是最合适入门深度学习的数据集,但是鉴于长期以来开发人员的使用习惯,想要找到完全替代Mnist的开源数据集确实有点困难,但这个难题最近有了一个比较好的解答,就是类似Mnist的一个服装图像数据集--Fashion Mnist

和Mnist数据集一样,Fashion Mnist也是28*28的灰度图。内容涵盖了鞋、包、衣服、裤子。它的文件名称和数据格式和Mnist一模一样。换句话说,你完全不需要改动你之前在Mnist上的建模逻辑,只需要把相应的文件替换掉,就可以对Fashion Mnist进行训练和评估。不过唯一不同的是,Fashion Mnist的分类准确率远没有Mnist那么高。目前在Fashion Mnist的github主页上,最好的结果也仅仅是在95%左右。当然,如果你自己的网络有了好的结果,可以在主页上提个issue,也作为是对这个数据集的一个贡献。

下面主要介绍3个方面的内容:

1.Fashion Mnist基于CNN的建模分类与评估

2.与Mnist的比较

3.简单的服装分类应用

首先介绍第一部分的主要内容。对于Fashion Mnist数据集采用卷积神经网络进行分类建模,具体的网络结构是:2Conv+2FC。建模工具是Deeplearning4j。详细的超参数配置见如下代码片段:

    public static MultiLayerNetwork getModel(){
        MultiLayerConfiguration.Builder builder = new NeuralNetConfiguration.Builder()
                        .seed(12345)
                        .iterations(1)
                        //.regularization(true).l2(0.005)
                        .learningRate(0.01)
                        .learningRateScoreBasedDecayRate(0.5)
                        .weightInit(WeightInit.XAVIER)
                        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                        .updater(Updater.ADAM)
                        .list()
                        .layer(0, new ConvolutionLayer.Builder(5, 5)
                                .nIn(1)
                                .stride(1, 1)
                                .nOut(32)
                                .activation(Activation.LEAKYRELU)
                                .build())
                        .layer(1, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                                .kernelSize(2,2)
                                .stride(2,2)
                                .build())
                        .layer(2, new ConvolutionLayer.Builder(5, 5)
                                .stride(1, 1)
                                .nOut(64)
                                .activation(Activation.LEAKYRELU)
                                .build())
                        .layer(3, new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                                .kernelSize(2,2)
                                .stride(2,2)
                                .build())
                        .layer(4, new DenseLayer.Builder().activation(Activation.LEAKYRELU)
                                .nOut(500).build())
                        .layer(5, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                                .nOut(10)
                                .activation(Activation.SOFTMAX)
                                .build())
                        .backprop(true).pretrain(false)
                        .setInputType(InputType.convolutionalFlat(28, 28, 1));
        MultiLayerConfiguration conf = builder.build();
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        return model; 
    }
简单解释下部分超参数:

激励函数部分主要用的是LeakeyRelu。

学习率用了Decay的策略。Decay的幅度是50%。

正则化项是可选的(经测试,正则化项在如上配置中,影响不大)

网络结构:2 Conv-with-MaxPooling + 2FC。卷积层中每层的featureMap的数量如上述所示。

除了建模的部分,数据的ETL部分同样很重要。在具体实现中,我直接利用Deeplearning4j自带的一个解析Mnist数据集的组件:MnistManager。它的主要功能就是读取解压后的二进制Mnist数据集以及相应的分类标签。由于Fashion Mnist和原始Mnist数据集在数据格式上完全相同,所以可以直接使用Mnist的组件进行解析。在读取的时候,我们可以根据要求设置batchSize,一个batch的数据和标签会封装在一个DataSet对象中。由这些DataSet构成的迭代器即可作为最终训练或者测试的数据。下面具体看下以上逻辑的实现:

    public static DataSet fetch(int batchSize , boolean binarize, MnistManager man, boolean save, boolean train) {        
        float[][] featureData = new float[batchSize][0];
        float[][] labelData = new float[batchSize][0];

        int actualExamples = 0;
        for (int i = 0; i < batchSize && cursor < totalExamples; i++, cursor++) {
            byte[] img = man.readImageUnsafe(order[cursor]);
            int label = man.readLabel(order[cursor]);
            
            float[] featureVec = new float[img.length];
            featureData[actualExamples] = featureVec;
            labelData[actualExamples] = new float[10];
            labelData[actualExamples][label] = 1.0f;

            for (int j = 0; j < img.length; j++) {
                float v = ((int) img[j]) & 0xFF; //byte is loaded as signed -> convert to unsigned
                if (binarize) {
                    if (v > 30.0f)
                        featureVec[j] = 1.0f;
                    else
                        featureVec[j] = 0.0f;
                } else {
                    featureVec[j] = v / 255.0f;
                }
            }
            if( save ){
                Mat mat = new Mat(28, 28, CV_8SC1, new BytePointer(img)); 
                
                if( train )
                    JavaCVUtil.imWrite(mat, "FashionMnist/trainData/" + label + "_" + cursor + ".jpg");
                else
                    JavaCVUtil.imWrite(mat, "FashionMnist/testData/" + label + "_" + cursor + ".jpg");
            }
            actualExamples++;
        }

        if (actualExamples < batchSize) {
            featureData = Arrays.copyOfRange(featureData, 0, actualExamples);
            labelData = Arrays.copyOfRange(labelData, 0, actualExamples);
        }

        INDArray features = Nd4j.create(featureData);
        INDArray labels = Nd4j.create(labelData);
        return new DataSet(features, labels);
    }
    
    public static DataSetIterator getData(String dir, boolean train , int batchSize, boolean save) throws IOException{
        String featureFileDir = dir;
        String labelFileDir = dir;
        cursor = 0;
        if( train ){
            featureFileDir += "train-images-idx3-ubyte";
            labelFileDir += "train-labels-idx1-ubyte";
            totalExamples = 60000;
            order = new int[totalExamples];
        }else{
            featureFileDir += "t10k-images-idx3-ubyte";
            labelFileDir += "t10k-labels-idx1-ubyte";
            totalExamples = 10000;
            order = new int[totalExamples];
        }
        for (int i = 0; i < order.length; i++)order[i] = i;
        MathUtils.shuffleArray(order, 123456L); //shuffle order
        MnistManager man = new MnistManager(featureFileDir, labelFileDir, train);
        List<DataSet> res = new LinkedList<DataSet>();
        while(cursor < totalExamples){
            res.add(fetch(batchSize, false, man, save, train));
        }
        ExistingDataSetIterator iter = new ExistingDataSetIterator(res);
        return iter;
    }

以上两个静态方法就是解析数据、读取标签、封装数据并生成可迭代数据集的过程。其中,getData这个方法可以根据参数的不同,生成训练或者测试数据集。在fetch这个方法里,可以选择是二值化还是正常归一化。我这里选择的是正常归一化。此外,为了方便看到Fashion Mnist的图像形式,可以选择是否以图片的形式生成这些图片。如果生成图片的话,则可以看到下面这些图:



这些图片我会上传到CSDN上供大家下载。下载连接

从截图可以看出是衣服和裤子两个品类。文件名中的第一个数字是这张图片的分类标签。这样方便直接从图片进行建模。训练集共6W张图片,测试集共1W张图片。

到此的话,数据的ETL和建模的步骤都已经完成,下面就是对模型参数进行训练。这里我还是用的GPU来训练模型。显卡是Telsa K80。单卡进行训练。相应的CUDA版本是8.0。具体的训练逻辑可见下面代码片段:

    public static void main(String[] args)throws IOException {
        DataTypeUtil.setDTypeForContext(DataBuffer.Type.DOUBLE);  
        final int numEpochs = Integer.parseInt(args[0]);
        final int batchSize = Integer.parseInt(args[1]);
        final String modelSavePath = args[2];
        final String dataPath = args[3];
        CudaEnvironment.getInstance().getConfiguration()
                        // 是否允许多卡
                        .allowMultiGPU(false)
                        .useDevice(7)
                        // 显存大小
                        .setMaximumDeviceCache(11L * 1024L * 1024L * 1024L)
                        // 是否允许多卡直接数据的直接访问
                        .allowCrossDeviceAccess(false);
        DataSetIterator trainData = getData(dataPath+ "/", true, batchSize, false);
        DataSetIterator testData = getData(dataPath+ "/" , false, batchSize, false);
        MultiLayerNetwork model = getModel();
        for( int i = 0; i < numEpochs; ++i ){
            model.fit(trainData);
            System.out.println("Epoch :" + i + " Finish");
            System.out.println("Score: " + model.score());
            Evaluation eval = model.evaluate(testData);
            System.out.println(eval.stats());
            System.out.println();  
        }
        Evaluation eval = model.evaluate(testData);
        System.out.println(eval.stats());
        ModelSerializer.writeModel(model, modelSavePath, true);
    }
其中batchSize等可以通过args参数传入设置。注意,最后我们把模型进行了保存。在每一轮的训练后,我们都打印了损失函数的值,并同时在测试集上评估了此时模型的准确性。我们一共训练了100轮。下面给出部分训练过程中的模型信息:

Epoch :0 Finish
Score: 0.4417036887606827

==========================Scores========================================
 Accuracy:        0.7986
 Precision:       0.8004
 Recall:          0.7986
 F1 Score:        0.7995
========================================================================
第一轮的loss值和模型评估。可以说,效果不佳。和Mnist的第一轮相差甚远(下面会有Mnist的相应训练信息)。

Epoch :99 Finish
Score: 0.020469321075697797

==========================Scores========================================
 Accuracy:        0.9072
 Precision:       0.9088
 Recall:          0.9072
 F1 Score:        0.908
========================================================================
100轮训练完之后,勉强达到了90%左右。应该说,结果一般。

到这里,我就没有再训练下去了。那么到此,第一部分的主要工作就完成了。最终经过100轮的训练,loss值达到0.02,模型的准确率在90%。

接着介绍下第二部分,也就是和Mnist比较的内容。

Mnist的训练过程和上面的一模一样,唯一不同的是,数据集换成Mnist的就可以了。同样经过100轮的训练,我们来看下对比结果。

  Mnist DataSet Fashion Mnist DataSet
Epoch 1
==========================Scores====================================
 Accuracy:        0.9545
 Precision:       0.955
 Recall:          0.954
 F1 Score:        0.9545
====================================================================
==========================Scores===================================
 Accuracy:        0.7986
 Precision:       0.8004
 Recall:          0.7986
 F1 Score:        0.7995
===================================================================
Epoch 100
==========================Scores====================================
 Accuracy:        0.9922
 Precision:       0.9921
 Recall:          0.9921
 F1 Score:        0.9921
====================================================================
==========================Scores=====================================
 Accuracy:        0.9072
 Precision:       0.9088
 Recall:          0.9072
 F1 Score:        0.908
=====================================================================
从表格里就可以直观的看出两个数据集在同样的模型、超参数配置下,最终评估效果的不同了。

Mnist数据集很容易就达到了95%的准确率,甚至最后达到了99.22%。然而Fashion Mnist最终也只有徘徊在90%上下。由此可见,Fashion Mnist数据集的分类问题更为复杂。2层卷积神经网络的效果可能也就是在90%左右了(PS:这个讲述并没有什么理论依据,但从github主页看到他人用Keras搭建类似结构的网络来训练Fashion Mnist,也是在90%上下,所以作此推测,仅仅是实验结果)。

最后一个部分介绍下基于刚才训练的模型如何搭建一个Web应用。

服装的分类场景在各大电商企业中有很多应用。虽然不一定需要准确区分运动鞋和休闲鞋,但是区分衣服、裤子、包、鞋还是很有必要的。这个场景在图像检索等应用方面有着类似文本检索中Query分析的作用,最终可以减少索引的查询量。这里就直接利用这样的一个开源数据集搭建一个Web服务,用于识别图片中物品的所属品类。涉及到的工具有Spring、Tomcat,JSP,还有之前提到的Deeplearning4j和Nd4j。

我在本地的Eclipse中配置了Tomcat的插件、服务的端口号、上下文的根路径等。在POM文件中引入了Spring和Deeplearning4j的相关依赖。最后前端页面上做了个简单的上传图片的按钮,最后的模型分类结果会和图片一起在页面上做展示。由于这里面涉及了关于J2EE开发的诸多细节,和主要介绍的内容有些偏离,所以这里仅仅介绍主要的思路。在后面的文章中,如果有机会的话会详细介绍Deeplearning4j训练的模型上线部署的一些方式,当然也包括一些采坑的地方。下面就给出一些示例结果:

这些服装类的图片是从苏宁易购的网站上面下载下来的,而且都是一些不需要做主体检测的、内容比较明确的图片。从实际的效果来看,确实可以对这些图片的品类做相对准确的识别。不过,其中也有误判的场景,比如长袖衬衫那个场景被预测成了外套。当然这只是一个demo,并不是最终可以达到产品效果的服务,而且在实际的应用中,像衬衫和外套一般并不会要求严格区分,毕竟单纯靠一张正面的图片就区分两个非常相似的品类是非常困难的,虽然并非一定不可以做到,但准确率未必可以保证。

这里有个地方需要注意:Fashion Mnist是28*28的灰度图。在做这个实际应用场景的时候,我同样对这些彩色图片做了灰度化以及resize的处理。换句话说,和训练数据保持一致对预测结果也同样重要。

最后对上面的内容做下小结。Fashion Mnist作为Mnist的替代数据集,无论在数据格式还是文件名称上都和原始的Mnist保持了高度一致,从而方便研发人员迁移之前的工作。但是,Fashion Mnist的分类比Mnist更有挑战性,至少从目前github主页上最优结果以及我自己的实验来看,很难达到和Mnist一样的准确性。原因的话,像外套、衬衫;靴子、运动鞋;难免存在外形极其相似的情况。因此,误判的情况会比较多。不过,从另一个角度说,这也说明相比Mnist,Fashion Mnist数据集的分类问题更为复杂。此外,Fashion Mnist也可以作为诸多电商企业商品图片分类的一个demo级别的测试数据集。通过做服装商品分类这样一个应用,可以对深度学习在产品级别应用的问题上有感性的认识,更重要的可能是发现深度神经网络的局限性,并非是万能的。这可能也是Fashion Mnist相比于原始Mnist数据集的价值所在,让大家对深度学习有理性的认识(原始Mnist很容易达到98%-99%的准确率,容易误导大家觉得深度学习就是这样准确,其实数据集本身也有非常大的关系,不能仅仅依靠模型)。

--------------以下更新自2018/3/21

在Deeplearning4j的QQ群里还有这篇文章的留言区有同学希望我补充下Web部分的代码,特此在这里做些补充。

首先,Web容器我用的是Tomcat-Eclipse的插件,在pom里的配置如下:

  <build> 
    <finalName>dl-webapp</finalName>  
    <plugins> 
      <plugin> 
        <groupId>org.apache.tomcat.maven</groupId>  
        <artifactId>tomcat7-maven-plugin</artifactId>  
        <version>2.2</version>  
        <configuration> 
          <port>8080</port>  
          <path>/maven-web-demo</path>  
          <uriEncoding>UTF-8</uriEncoding>  
          <finalName>maven-web-demo</finalName>  
          <server>tomcat7</server> 
        </configuration>  
        <executions> 
          <!-- 打包成功后即开始运行web容器 -->  
          <execution> 
            <phase>package</phase>  
            <goals> 
              <goal>run</goal> 
            </goals> 
          </execution> 
        </executions> 
      </plugin> 
    </plugins> 
  </build>

其次,整个Web工程的编译目录结构如下(工程名:DL):


猜你喜欢

转载自blog.csdn.net/wangongxi/article/details/78475150