Deeplearning4j training and inference case (2023): handwritten digit recognition

1. MNIST dataset

Download link: a training set of 60,000 images and a test set of 10,000 images.
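The training code in section 3 expects the archive unzipped into digit-named subfolders, one per class, since ParentPathLabelGenerator reads the label from each image's parent directory. Roughly:

mnist_png
├── training
│   └── 0/ 1/ ... 9/   (one folder per digit; the folder name is the label)
└── testing
    └── 0/ 1/ ... 9/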

2. Dependencies

Mainly deeplearning4j and javacv packages. The JAR produced for this case is about 1.3 GB. The POM is adapted from the dl4j-examples module of the deeplearning4j-examples project on GitHub.

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>
    <parent>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-parent</artifactId>
        <version>2.7.9</version>
        <relativePath/>
    </parent>
    <groupId>com.example</groupId>
    <artifactId>demo</artifactId>
    <version>0.0.1-SNAPSHOT</version>
    <name>demo</name>
    <description>demo</description>
    <properties>
        <dl4j-master.version>1.0.0-M2.1</dl4j-master.version>
        <nd4j.backend>nd4j-native</nd4j.backend>
        <java.version>17</java.version>
        <maven-compiler-plugin.version>3.8.1</maven-compiler-plugin.version>
        <maven.minimum.version>3.3.1</maven.minimum.version>
        <exec-maven-plugin.version>1.4.0</exec-maven-plugin.version>
        <maven-shade-plugin.version>2.4.3</maven-shade-plugin.version>
        <jcommon.version>1.0.23</jcommon.version>
        <jfreechart.version>1.0.13</jfreechart.version>
        <logback.version>1.1.7</logback.version>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <junit.version>5.8.0-M1</junit.version>
        <javacv.version>1.5.9</javacv.version>
    </properties>
    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.bytedeco</groupId>
                <artifactId>javacv-platform</artifactId>
                <version>${javacv.version}</version>
            </dependency>
        </dependencies>
    </dependencyManagement>
    <dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter</artifactId>
        </dependency>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>${nd4j.backend}</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>


        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-image</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-local</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-datasets</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>resources</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-ui</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-zoo</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <!-- ParallelWrapper & ParallelInference live here -->
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-parallel-wrapper</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>
        <!-- Used in the feedforward/classification/MLP* and feedforward/regression/RegressionMathFunctions example -->
        <dependency>
            <groupId>jfree</groupId>
            <artifactId>jfreechart</artifactId>
            <version>${jfreechart.version}</version>
        </dependency>
        <dependency>
            <groupId>org.jfree</groupId>
            <artifactId>jcommon</artifactId>
            <version>${jcommon.version}</version>
        </dependency>
        <!-- Used for downloading data in some of the examples -->
        <dependency>
            <groupId>org.apache.httpcomponents</groupId>
            <artifactId>httpclient</artifactId>
            <version>4.3.5</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <version>${logback.version}</version>
        </dependency>

        <dependency>
            <groupId>org.bytedeco</groupId>
            <artifactId>javacv-platform</artifactId>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-api</artifactId>
            <version>${dl4j-master.version}</version>
        </dependency>

    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.springframework.boot</groupId>
                <artifactId>spring-boot-maven-plugin</artifactId>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>17</source>
                    <target>17</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

</project>
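With this POM, mvn -DskipTests clean package builds the JAR; the ~1.3 GB size comes mostly from nd4j-native and javacv-platform, which bundle native binaries for every supported OS and architecture. If you have an NVIDIA GPU, the ND4J backend can in principle be swapped by changing one property; a sketch, assuming the CUDA 11.6 artifacts published for 1.0.0-M2.1 (verify the coordinates against the DL4J docs):

    <!-- replaces nd4j-native; assumed coordinates for the CUDA backend -->
    <nd4j.backend>nd4j-cuda-11.6</nd4j.backend>

To shrink the JAR instead, building with -Djavacpp.platform=windows-x86_64 (or your own platform) keeps only the matching natives.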

3. Handwritten digit training and inference

One epoch of training takes about 100 seconds and reaches 97% accuracy. See the code comments for details; the framework's API is fairly easy to use.

package ai;

import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.api.io.labels.ParentPathLabelGenerator;
import org.datavec.api.split.FileSplit;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.api.OptimizationAlgorithm;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.model.stats.StatsListener;
import org.deeplearning4j.ui.model.storage.FileStatsStorage;
import org.deeplearning4j.ui.model.storage.InMemoryStatsStorage;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.common.io.Assert;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.nd4j.linalg.schedule.MapSchedule;
import org.nd4j.linalg.schedule.ScheduleType;

import java.io.File;
import java.util.Random;

@Slf4j
public class LeNetMNISTReLu {

    private static final String DATASET_PATH_BASE = "D:\\";

    public static void main(String[] args) throws Exception {
        int height = 28;
        int width = 28;
        // Grayscale images have a single channel
        int channels = 1;
        // Ten classes: the digits 0-9
        int outputNum = 10;
        int batchSize = 64;
        // One epoch takes about 100 s here; 3 epochs reach 99% accuracy
        int nEpochs = 1;

        Assert.isTrue(new File(DATASET_PATH_BASE + "/mnist_png").exists(), "Please download the archive and unzip it to " + DATASET_PATH_BASE);
        // This label generator uses the parent directory name as the label and requires it to be numeric;
        // the MNIST dataset conveniently stores its images in folders named 0-9
        ParentPathLabelGenerator labelMaker = new ParentPathLabelGenerator();
        // Scale pixel values into [0, 1]
        DataNormalization normalization = new ImagePreProcessingScaler();
        Random random = new Random(12345);
        log.info("Training set: 60,000 images...");
        File trainData = new File(DATASET_PATH_BASE + "/mnist_png/training");
        FileSplit trainSplit = new FileSplit(trainData, NativeImageLoader.ALLOWED_FORMATS, random);
        ImageRecordReader trainRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
        trainRecordReader.initialize(trainSplit);
        DataSetIterator trainIter = new RecordReaderDataSetIterator(trainRecordReader, batchSize, 1, outputNum);
        normalization.fit(trainIter);
        trainIter.setPreProcessor(normalization); // normalize pixels before feeding the network

        log.info("验证集1W张...");
        File validateData = new File(DATASET_PATH_BASE + "/mnist_png/testing");
        FileSplit validateSplit = new FileSplit(validateData, NativeImageLoader.ALLOWED_FORMATS, random);
        ImageRecordReader validateRecordReader = new ImageRecordReader(height, width, channels, labelMaker);
        validateRecordReader.initialize(validateSplit);
        DataSetIterator validateIter = new RecordReaderDataSetIterator(validateRecordReader, batchSize, 1, outputNum);
        validateIter.setPreProcessor(normalization);

        // 60,000 training samples with batchSize=64 gives roughly 1,000 iterations per epoch
        // Learning-rate schedule: update the rate (step size) every few hundred iterations, starting larger;
        // it could also be updated once per epoch instead
        MapSchedule mapSchedule = new MapSchedule.Builder(ScheduleType.ITERATION)
                .add(0, 0.06)
                .add(200, 0.05)
                .add(600, 0.028)
                .add(800, 0.006)
                .add(1000, 0.001)
                .build();

        // Hyperparameters
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(1)
                .l2(0.0005)
                .updater(new Nesterovs(mapSchedule))
                //.optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT) // this optimizer failed to converge even after long training
                .weightInit(WeightInit.XAVIER)
                .list()
                .layer(new ConvolutionLayer.Builder(5, 5)
                        .nIn(channels)
                        .stride(1, 1)
                        .nOut(20)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new ConvolutionLayer.Builder(5, 5)
                        .stride(1, 1)
                        .nOut(50)
                        .activation(Activation.IDENTITY)
                        .build())
                .layer(new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                        .kernelSize(2, 2)
                        .stride(2, 2)
                        .build())
                .layer(new DenseLayer.Builder().activation(Activation.RELU)
                        .nOut(500)
                        .build())
                .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .build())
                .setInputType(InputType.convolutionalFlat(height, width, channels)) // InputType.convolutional for normal image
                .build();

        // Build the network
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();
        // Training monitor: print the loss every 10 iterations
        net.setListeners(new ScoreIterationListener(10));
        // Optional: web UI for monitoring the training process
        //UIServer uiServer = UIServer.getInstance();
        //FileStatsStorage statsStorage = new FileStatsStorage(new File("D:\\ai-webui.dat"));
        //uiServer.attach(statsStorage);
        //net.setListeners(new StatsListener(statsStorage));
        log.info("Number of network parameters: {}", net.numParams());
        long startTime = System.currentTimeMillis();
        // Train for nEpochs epochs
        for (int i = 0; i < nEpochs; i++) {
            log.info("Epoch={}", i);
            net.fit(trainIter);
            Evaluation eval = net.evaluate(validateIter);
            log.info(eval.stats());
            trainIter.reset();
            validateIter.reset();
        }
        log.info("Training took {} ms", System.currentTimeMillis() - startTime);
        // Save the model
        File ministModelPath = new File(DATASET_PATH_BASE + "/ministModel.zip");
        ModelSerializer.writeModel(net, ministModelPath, true);
        // Inference: load the network (model) -> load test images -> predict
        MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(ministModelPath);
        FileUtils.listFiles(new File(DATASET_PATH_BASE + "/mnist_png/testing"), null, true)
                .parallelStream().forEach(file -> {
                    try {
                        // NativeImageLoader is not thread-safe; create one per call when streaming in parallel
                        NativeImageLoader imageLoader = new NativeImageLoader(height, width, channels);
                        INDArray matrix = imageLoader.asMatrix(file);
                        // Apply the same pixel scaling used during training
                        normalization.transform(matrix);
                        INDArray output = network.output(matrix);
                        // Take the most likely class
                        int predictedValue = Nd4j.argMax(output, 1).getInt(0);
                        // Images sit in folders named after their digit, so the parent folder name is the ground truth
                        String realValue = file.getParentFile().getName();
                        log.info("Actual: {}, predicted: {}", realValue, predictedValue);
                        Assert.isTrue(predictedValue == Integer.parseInt(realValue), file.getAbsolutePath() + " was mispredicted");
                    } catch (Exception e) {
                        log.warn(e.getMessage(), e);
                    }
                });
    }
}
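Since the POM already includes deeplearning4j-parallel-wrapper, the saved model can also be served concurrently through ParallelInference rather than calling network.output() directly from many threads. A minimal sketch, assuming the ministModel.zip written above; the worker count, batch limit, and image path are illustrative, not tuned:

package ai;

import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.parallelism.ParallelInference;
import org.deeplearning4j.parallelism.inference.InferenceMode;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;

import java.io.File;

public class MnistParallelInference {
    public static void main(String[] args) throws Exception {
        MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(new File("D:\\ministModel.zip"));
        ParallelInference inference = new ParallelInference.Builder(net)
                .inferenceMode(InferenceMode.BATCHED) // merge concurrent requests into mini-batches
                .batchLimit(32)                       // max requests combined into one forward pass
                .workers(2)                           // number of model copies serving requests
                .build();
        // Load and scale one test image exactly as in training (the path is a placeholder)
        NativeImageLoader loader = new NativeImageLoader(28, 28, 1);
        INDArray img = loader.asMatrix(new File("D:\\mnist_png\\testing\\7\\example.png"));
        new ImagePreProcessingScaler(0, 1).transform(img);
        INDArray scores = inference.output(img);
        System.out.println("Predicted digit: " + Nd4j.argMax(scores, 1).getInt(0));
    }
}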

4. Further reading: the bundled deeplearning4j-examples learning project

Refer to the deeplearning4j-examples README and pick a few cases to run. First, download the four utility classes under the util package of dl4j-examples (e.g. DownloaderUtility, which the code below uses).
① The classic iris classification case; it runs as-is.
IrisClassifier.java

/* *****************************************************************************
 *
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package ai;

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.learning.config.Sgd;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.File;

/**
 * @author Adam Gibson
 */
@SuppressWarnings("DuplicatedCode")
public class IrisClassifier {

    private static Logger log = LoggerFactory.getLogger(IrisClassifier.class);

    public static void main(String[] args) throws Exception {
        //First: get the dataset using the record reader. CSVRecordReader handles loading/parsing
        int numLinesToSkip = 0;
        char delimiter = ',';
        RecordReader recordReader = new CSVRecordReader(numLinesToSkip,delimiter);
        recordReader.initialize(new FileSplit(new File(DownloaderUtility.IRISDATA.Download(),"iris.txt")));

        //Second: the RecordReaderDataSetIterator handles conversion to DataSet objects, ready for use in neural network
        int labelIndex = 4;     //5 values in each row of the iris.txt CSV: 4 input features followed by an integer label (class) index. Labels are the 5th value (index 4) in each row
        int numClasses = 3;     //3 classes (types of iris flowers) in the iris data set. Classes have integer values 0, 1 or 2
        int batchSize = 150;    //Iris data set: 150 examples total. We are loading all of them into one DataSet (not recommended for large data sets)

        DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader,batchSize,labelIndex,numClasses);
        DataSet allData = iterator.next();
        allData.shuffle();
        SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);  //Use 65% of data for training

        DataSet trainingData = testAndTrain.getTrain();
        DataSet testData = testAndTrain.getTest();

        //We need to normalize our data. We'll use NormalizeStandardize (which gives us mean 0, unit variance):
        DataNormalization normalizer = new NormalizerStandardize();
        normalizer.fit(trainingData);           //Collect the statistics (mean/stdev) from the training data. This does not modify the input data
        normalizer.transform(trainingData);     //Apply normalization to the training data
        normalizer.transform(testData);         //Apply normalization to the test data. This is using statistics calculated from the *training* set


        final int numInputs = 4;
        int outputNum = 3;
        long seed = 6;


        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(seed)
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .updater(new Sgd(0.1))
                .l2(1e-4)
                .list()
                .layer(new DenseLayer.Builder().nIn(numInputs).nOut(3)
                        .build())
                .layer(new DenseLayer.Builder().nIn(3).nOut(3)
                        .build())
                .layer( new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX) //Override the global TANH activation with softmax for this layer
                        .nIn(3).nOut(outputNum).build())
                .build();

        //run the model
        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        //record score once every 100 iterations
        model.setListeners(new ScoreIterationListener(100));

        for (int i = 0; i < 1000; i++) {
            model.fit(trainingData);
        }

        //evaluate the model on the test set
        Evaluation eval = new Evaluation(3);
        INDArray output = model.output(testData.getFeatures());
        eval.eval(testData.getLabels(), output);
        log.info(eval.stats());

    }

}
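To classify a single new flower with the trained model, apply the same normalizer before calling output(). A minimal sketch that would continue inside main() above; the measurements are made up, and Nd4j (org.nd4j.linalg.factory.Nd4j) must be imported:

        // Hypothetical sample: sepal length, sepal width, petal length, petal width
        INDArray sample = Nd4j.create(new double[][]{{5.1, 3.5, 1.4, 0.2}});
        normalizer.transform(sample);      // reuse the train-set mean/stdev statistics
        INDArray out = model.output(sample);
        int predictedClass = Nd4j.argMax(out, 1).getInt(0);
        log.info("Predicted iris class: {}", predictedClass);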

② A simpler MNIST classification
MNISTSingleLayer.java. On a low-end machine, reduce batchSize and the number of epochs, otherwise training takes a long time.

/* *****************************************************************************
 *
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package ai;

import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions.LossFunction;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**A Simple Multi Layered Perceptron (MLP) applied to digit classification for
 * the MNIST Dataset (http://yann.lecun.com/exdb/mnist/).
 *
 * This file builds one hidden layer and one output layer.
 *
 * The hidden layer has input dimension of numRows*numColumns where these variables indicate the
 * number of vertical and horizontal pixels in the image. This layer uses a rectified linear unit
 * (relu) activation function. The weights for this layer are initialized using Xavier initialization
 * (https://prateekvjoshi.com/2016/03/29/understanding-xavier-initialization-in-deep-neural-networks/)
 * to avoid having a steep learning curve. This layer sends 1000 output signals to the output layer.
 *
 * The output layer has input dimension of 1000, fed from the hidden layer. Its weights
 * are also initialized using Xavier initialization. The activation function for this
 * layer is softmax, which normalizes all the 10 outputs such that they
 * add up to 1. The highest of these normalized values is picked as the predicted class.
 *
 */
public class MNISTSingleLayer {

    private static Logger log = LoggerFactory.getLogger(MNISTSingleLayer.class);

    public static void main(String[] args) throws Exception {
        //number of rows and columns in the input pictures
        final int numRows = 28;
        final int numColumns = 28;
        int outputNum = 10; // number of output classes
        int batchSize = 64; // batch size for each epoch
        int rngSeed = 123; // random number seed for reproducibility
        int numEpochs = 1; // number of epochs to perform

        //Get the DataSetIterators:
        DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, rngSeed);
        DataSetIterator mnistTest = new MnistDataSetIterator(batchSize, false, rngSeed);


        log.info("Build model....");
        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .seed(rngSeed) //include a random seed for reproducibility
                // use stochastic gradient descent as an optimization algorithm
                .updater(new Nesterovs(0.006, 0.9))
                .l2(1e-4)
                .list()
                .layer(new DenseLayer.Builder() //create the first, input layer with xavier initialization
                        .nIn(numRows * numColumns)
                        .nOut(1000)
                        .activation(Activation.RELU)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .layer(new OutputLayer.Builder(LossFunction.NEGATIVELOGLIKELIHOOD) //create the output layer
                        .nIn(1000)
                        .nOut(outputNum)
                        .activation(Activation.SOFTMAX)
                        .weightInit(WeightInit.XAVIER)
                        .build())
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(conf);
        model.init();
        //print the score every iteration
        model.setListeners(new ScoreIterationListener(1));

        log.info("Train model....");
        model.fit(mnistTrain, numEpochs);


        log.info("Evaluate model....");
        Evaluation eval = model.evaluate(mnistTest);
        log.info(eval.stats());
        log.info("****************Example finished********************");

    }

}
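The example stops at evaluation. To reuse the trained MLP, the same ModelSerializer calls from section 3 apply; a minimal sketch appended to main() (the file name is arbitrary, and ModelSerializer plus java.io.File need importing):

        File modelFile = new File("mnistSingleLayer.zip");
        ModelSerializer.writeModel(model, modelFile, true); // true = also save updater state for further training
        MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(modelFile);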

③ Use an RNN/CNN model for sentiment classification on the IMDB dataset

Just download the three Java files under the GitHub code path deeplearning4j-examples/dl4j-examples/src/main/java/org/deeplearning4j/examples/advanced/modelling/textclassification/pretrainedword2vec.

  • Download the IMDB dataset; it is best to fetch it manually and place it in the DATA_PATH folder referenced by ImdbReviewClassificationRNN.java
  • After manually downloading the word-vector file, assign its path to the wordVectorsPath variable. On my low-end 16 GB machine, loading the word vectors caused an out-of-memory error; see the JVM flags sketch after this list
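If you hit the same out-of-memory error loading the word vectors, raising both the JVM heap and ND4J's off-heap JavaCPP limits usually helps. A sketch with illustrative values (scale them to your RAM; maxphysicalbytes should cover heap plus off-heap):

java -Xmx6g -Dorg.bytedeco.javacpp.maxbytes=10G -Dorg.bytedeco.javacpp.maxphysicalbytes=14G -cp <your classpath> ImdbReviewClassificationRNN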

Origin: blog.csdn.net/qq_39506978/article/details/134024627